From 66dc16f96636ae7231ebd7b7dc788837422524a2 Mon Sep 17 00:00:00 2001 From: luopingyi Date: Tue, 14 Oct 2025 10:38:28 +0800 Subject: [PATCH 1/2] init v0.11.0rc0 --- Dockerfile | 4 +- Dockerfile.310p | 61 + Dockerfile.310p.openEuler | 59 + Dockerfile.a3 | 60 + Dockerfile.a3.openEuler | 58 + Dockerfile.openEuler | 58 + README.md | 4 +- README.md.ori | 6 +- README.zh.md | 6 +- benchmarks/ops/ben_vocabparallelembedding.py | 2 +- .../scripts/run-performance-benchmarks.sh | 4 +- benchmarks/tests/serving-tests.json | 3 +- csrc/torch_binding.cpp | 31 +- csrc/torch_binding_meta.cpp | 8 +- docs/source/community/versioning_policy.md | 4 + docs/source/conf.py | 10 +- .../accuracy_report/DeepSeek-V2-Lite.md | 20 + .../accuracy_report/Qwen2.5-VL-7B-Instruct.md | 19 + .../accuracy_report/Qwen3-30B-A3B.md | 21 + .../accuracy_report/Qwen3-8B-Base.md | 21 + .../evaluation/accuracy_report/index.md | 4 + .../modeling/adding_a_new_model.md | 1 - docs/source/faqs.md | 20 +- docs/source/installation.md | 1 + .../configuration/additional_config.po | 8 +- docs/source/tutorials/index.md | 4 + .../tutorials/multi_node_pd_disaggregation.md | 244 ++ docs/source/tutorials/multi_node_qwen3vl.md | 156 + docs/source/tutorials/multi_node_ray.md | 182 ++ docs/source/tutorials/multi_npu_qwen3_next.md | 156 + .../configuration/additional_config.md | 19 +- .../feature_guide/eplb_swift_balancer.md | 94 + .../feature_guide/images/eplb_img.png | Bin 0 -> 56081 bytes docs/source/user_guide/feature_guide/index.md | 1 + .../user_guide/feature_guide/quantization.md | 5 +- docs/source/user_guide/release_notes.md | 65 + examples/disaggregated_prefill_v1/README.md | 16 +- .../disaggregated_prefill_v1/gen_ranktable.py | 70 +- .../disaggregated_prefill_v1/gen_ranktable.sh | 11 +- .../load_balance_proxy_server_example.py | 1 + ...oncake_connector_store_deployment_guide.md | 272 ++ .../external_online_dp/run_dp_template.sh | 2 +- examples/offline_disaggregated_prefill_npu.py | 2 +- examples/offline_weight_load.py | 
326 ++ examples/run_dp_server.sh | 2 +- requirements-dev.txt | 2 +- tests/e2e/common.sh | 2 +- tests/e2e/conftest.py | 9 +- tests/e2e/doctests/001-quickstart-test.sh | 6 +- .../002-pip-binary-installation-test.sh | 2 +- tests/e2e/model_utils.py | 7 +- .../e2e/models/configs/DeepSeek-V2-Lite.yaml | 8 +- .../configs/Qwen2.5-VL-7B-Instruct.yaml | 2 + tests/e2e/models/configs/Qwen3-30B-A3B.yaml | 2 + tests/e2e/models/configs/Qwen3-8B-Base.yaml | 2 + tests/e2e/models/configs/accuracy.txt | 1 + tests/e2e/models/report_template.md | 18 +- tests/e2e/models/test_lm_eval_correctness.py | 10 +- tests/e2e/multicard/test_expert_parallel.py | 22 +- .../test_offline_inference_distributed.py | 60 +- tests/e2e/multicard/test_prefix_caching.py | 42 +- tests/e2e/multicard/test_qwen3_moe.py | 1 - .../e2e/multicard/test_torchair_graph_mode.py | 3 + tests/e2e/multicard/test_weight_loader.py | 188 ++ .../e2e/pd_disaggreate/run_edge_case_test.sh | 4 +- tests/e2e/run_doctests.sh | 1 - tests/e2e/singlecard/ops/test_bgmv_expand.py | 4 +- tests/e2e/singlecard/ops/test_bgmv_shrink.py | 2 +- tests/e2e/singlecard/ops/test_fused_moe.py | 104 +- tests/e2e/singlecard/ops/test_moe_comm.py | 175 - .../singlecard/ops/test_rotary_embedding.py | 6 +- .../ops/test_vocabparallelembedding.py | 2 +- .../spec_decode_v1/test_v1_mtp_correctness.py | 28 +- .../test_v1_mtp_torchair_correctness.py | 4 - .../spec_decode_v1/test_v1_spec_decode.py | 4 - tests/e2e/singlecard/test_ascend_scheduler.py | 23 + tests/e2e/singlecard/test_guided_decoding.py | 77 +- .../test_multistream_overlap_shared_expert.py | 103 + tests/e2e/singlecard/test_vlm.py | 19 +- .../vllm_interface/singlecard/test_sampler.py | 36 + tests/e2e/vllm_interface/vllm_test.cfg | 2 + tests/ut/attention/test_attention_v1.py | 108 +- tests/ut/attention/test_mla_v1.py | 42 +- tests/ut/compilation/test_acl_graph.py | 720 +++++ tests/ut/core/test_schedule_config.py | 63 +- tests/ut/core/test_scheduler.py | 610 ++-- .../test_determin_expert_map_all.py | 0 
.../test_distributed_tensor_parallel.py | 139 - tests/ut/distributed/test_parallel_state.py | 10 +- .../ut/eplb/adaptor/test_abstract_adaptor.py | 73 + .../eplb/core/policy/test_policy_abstract.py | 31 + .../core/policy/test_policy_dynamic_ep.py | 98 + .../core/policy/test_policy_dynamic_ep_v2.py | 99 + .../ut/eplb/core/policy/test_policy_factor.py | 23 + .../core/test_eplb_device_transfer_loader.py | 122 + tests/ut/eplb/core/test_eplb_utils.py | 79 + .../kv_connector/test_mooncake_connector.py | 71 +- tests/ut/kv_connector/utils.py | 67 +- tests/ut/models/conftest.py | 114 + tests/ut/models/test_deepseek_mtp.py | 13 +- tests/ut/models/test_deepseek_v2.py | 208 +- tests/ut/models/test_qwen2_5_vl.py | 56 + tests/ut/models/test_qwen3_moe.py | 30 - tests/ut/ops/test_activation.py | 13 +- tests/ut/ops/test_comm_utils.py | 98 + tests/ut/ops/test_common_fused_moe.py | 69 +- .../test_fused_moe_prepare_and_finalize.py | 289 ++ tests/ut/ops/test_fused_ops.py | 315 +- tests/ut/ops/test_layernorm.py | 170 +- tests/ut/ops/test_linear.py | 393 +-- tests/ut/ops/test_moe_comm_method.py | 232 ++ tests/ut/ops/test_rotary_embedding.py | 132 +- tests/ut/ops/test_token_dispatcher.py | 228 +- tests/ut/ops/test_vocab_parallel_embedding.py | 14 +- .../worker/patch_common/test_patch_linear.py | 167 - tests/ut/quantization/test_func_wrapper.py | 134 - tests/ut/quantization/test_quant_config.py | 45 +- tests/ut/quantization/test_quantizer.py | 145 - tests/ut/quantization/test_utils.py | 62 + tests/ut/quantization/test_w4a8_dynamic.py | 164 +- tests/ut/quantization/test_w8a8.py | 6 +- tests/ut/quantization/test_w8a8_dynamic.py | 69 + .../sample/logits_processor/test_builtin.py | 40 + tests/ut/test_ascend_config.py | 59 +- tests/ut/test_platform.py | 37 +- tests/ut/test_utils.py | 25 +- .../models/test_torchair_deepseek_mtp.py | 2 - .../models/test_torchair_deepseek_v2.py | 16 +- .../torchair/ops/test_torchair_fused_moe.py | 26 +- .../ops/test_torchair_rotary_embedding.py | 53 +- 
.../test_torchair_w4a8_dynamic.py | 129 +- tests/ut/torchair/test_torchair_attention.py | 95 + tests/ut/torchair/test_torchair_mla.py | 46 +- tests/ut/torchair/test_utils.py | 13 - tests/ut/worker/test_input_batch.py | 2 +- tests/ut/worker/test_model_runner_v1.py | 107 + tests/ut/worker/test_worker_v1.py | 100 +- vllm_ascend/__init__.py | 2 + vllm_ascend/ascend_config.py | 49 +- vllm_ascend/ascend_forward_context.py | 93 +- vllm_ascend/attention/attention_mask.py | 51 +- vllm_ascend/attention/attention_v1.py | 242 +- vllm_ascend/attention/mla_v1.py | 179 +- vllm_ascend/attention/sfa_v1.py | 986 ++++++ vllm_ascend/attention/utils.py | 46 +- vllm_ascend/compilation/acl_graph.py | 87 +- vllm_ascend/core/schedule_config.py | 42 +- vllm_ascend/core/scheduler.py | 215 +- vllm_ascend/distributed/__init__.py | 5 + .../distributed/cpu_offload_connector.py | 457 +++ .../cpu_offload_manager}/__init__.py | 0 .../cpu_kv_cache_manager.py | 202 ++ .../cpu_offload_manager/metadata.py | 269 ++ .../llmdatadist_c_mgr_connector.py | 189 +- vllm_ascend/distributed/moe_comm_method.py | 556 ---- .../mooncake}/__init__.py | 0 .../distributed/mooncake/config_data.py | 447 +++ .../distributed/mooncake/kv_transfer.py | 251 ++ .../distributed/mooncake/mooncake_engine.py | 489 +++ .../distributed/mooncake/mooncake_store.py | 88 + .../mooncake/mooncake_store_connector_v1.py | 484 +++ vllm_ascend/distributed/mooncake_connector.py | 84 +- vllm_ascend/distributed/parallel_state.py | 27 +- vllm_ascend/distributed/tensor_parallel.py | 248 -- vllm_ascend/envs.py | 27 +- vllm_ascend/eplb/__init__.py | 0 vllm_ascend/eplb/adaptor/__init__.py | 0 vllm_ascend/eplb/adaptor/abstract_adaptor.py | 44 + vllm_ascend/eplb/adaptor/vllm_adaptor.py | 289 ++ vllm_ascend/eplb/core/__init__.py | 0 .../eplb/core/eplb_device_transfer_loader.py | 137 + vllm_ascend/eplb/core/eplb_utils.py | 135 + vllm_ascend/eplb/core/eplb_worker.py | 436 +++ vllm_ascend/eplb/core/policy/__init__.py | 0 
.../eplb/core/policy/policy_abstract.py | 42 + .../eplb/core/policy/policy_dynamic_ep.py | 389 +++ .../eplb/core/policy/policy_dynamic_ep_v2.py | 771 +++++ .../eplb/core/policy/policy_factory.py | 33 + .../eplb/core/policy/policy_flashlb.py | 651 ++++ vllm_ascend/eplb/core/policy/policy_random.py | 30 + vllm_ascend/eplb/eplb_updator.py | 205 ++ vllm_ascend/eplb/utils.py | 77 + .../lora/{punica_wrapper => }/lora_ops.py | 25 +- .../lora/{punica_wrapper => }/punica_npu.py | 26 +- vllm_ascend/lora/utils.py | 110 + vllm_ascend/meta_registration.py | 13 +- vllm_ascend/models/__init__.py | 61 +- vllm_ascend/models/deepseek_dbo.py | 1046 ------ vllm_ascend/models/deepseek_mtp.py | 31 +- vllm_ascend/models/deepseek_v2.py | 897 ++---- vllm_ascend/models/deepseek_v3.py | 27 - vllm_ascend/models/layers/__init__.py | 0 vllm_ascend/models/layers/mla.py | 180 ++ vllm_ascend/models/layers/sfa.py | 233 ++ vllm_ascend/models/pangu_moe.py | 1106 ------- vllm_ascend/models/qwen2_5_vl.py | 74 +- .../models/qwen2_5_vl_without_padding.py | 296 +- vllm_ascend/models/qwen2_vl.py | 24 +- vllm_ascend/models/qwen3.py | 156 - vllm_ascend/models/qwen3_moe.py | 147 +- vllm_ascend/models/qwen3_next.py | 676 ++++ vllm_ascend/ops/__init__.py | 18 +- vllm_ascend/ops/activation.py | 2 + vllm_ascend/ops/casual_conv1d.py | 539 ++++ vllm_ascend/ops/common_fused_moe.py | 613 ++-- vllm_ascend/ops/fla.py | 218 ++ vllm_ascend/ops/fused_moe.py | 360 +-- vllm_ascend/ops/layernorm.py | 128 +- vllm_ascend/ops/linear.py | 440 +-- vllm_ascend/ops/linear_op.py | 459 +++ vllm_ascend/ops/moe/__init__.py | 0 vllm_ascend/ops/{ => moe}/comm_utils.py | 53 +- .../ops/{layers => moe}/experts_selector.py | 0 .../ops/moe/fused_moe_prepare_and_finalize.py | 459 +++ vllm_ascend/ops/moe/moe_comm_method.py | 273 ++ vllm_ascend/ops/{layers => moe}/moe_mlp.py | 173 +- .../token_dispatcher.py | 259 +- vllm_ascend/ops/register_custom_ops.py | 201 ++ vllm_ascend/ops/rotary_embedding.py | 104 +- vllm_ascend/ops/sigmoid_gating.py | 
384 +++ vllm_ascend/ops/vocab_parallel_embedding.py | 14 + vllm_ascend/patch/__init__.py | 39 +- .../patch/platform/patch_common/__init__.py | 6 + .../platform/patch_common/patch_config.py | 313 ++ .../patch_common/patch_mamba_config.py | 100 + .../patch_common/patch_multimodal_merge.py | 58 + .../patch_common/patch_transformers_utils.py | 200 ++ .../patch/worker/patch_common/__init__.py | 16 +- .../patch_common/patch_attention_layer.py | 202 ++ .../patch_common/patch_attention_selector.py | 181 ++ .../patch_common/patch_attentionspec.py | 110 + .../patch/worker/patch_common/patch_linear.py | 147 - .../patch_common/patch_lora_embedding.py | 29 - .../patch/worker/patch_common/patch_triton.py | 16 + .../patch_common/patch_weight_loader.py | 44 + vllm_ascend/platform.py | 138 +- vllm_ascend/quantization/func_wrapper.py | 184 -- vllm_ascend/quantization/quant_config.py | 140 +- vllm_ascend/quantization/quantizer.py | 311 -- vllm_ascend/quantization/utils.py | 83 + vllm_ascend/quantization/w4a8_dynamic.py | 131 +- vllm_ascend/quantization/w8a8.py | 2 +- vllm_ascend/quantization/w8a8_dynamic.py | 212 +- .../sample/logits_processor/__init__.py | 50 + .../sample/logits_processor/builtin.py | 35 + vllm_ascend/sample/sampler.py | 24 +- vllm_ascend/spec_decode/__init__.py | 33 + vllm_ascend/spec_decode/eagle_proposer.py | 674 ++++ vllm_ascend/spec_decode/interface.py | 51 + vllm_ascend/spec_decode/mtp_proposer.py | 657 ++++ vllm_ascend/spec_decode/ngram_proposer.py | 65 + vllm_ascend/torchair/models/qwen2.py | 7 +- vllm_ascend/torchair/models/qwen3_moe.py | 19 +- .../torchair/models/torchair_deepseek_mtp.py | 6 +- .../torchair/models/torchair_deepseek_v2.py | 315 +- .../torchair/models/torchair_pangu_moe.py | 19 +- .../{ => torchair}/ops/sequence_parallel.py | 0 .../torchair/ops/shared_weight_layer.py | 245 ++ .../ops/torchair_activation.py} | 62 +- .../torchair/ops/torchair_fused_moe.py | 189 +- .../torchair/ops/torchair_layernorm.py | 51 + 
.../torchair/ops/torchair_rotary_embedding.py | 23 +- .../quantization/torchair_quantizer.py | 29 - .../quantization/torchair_w4a8_dynamic.py | 107 +- .../quantization/torchair_w8a8_dynamic.py | 35 +- vllm_ascend/torchair/torchair_attention.py | 45 +- vllm_ascend/torchair/torchair_mla.py | 224 +- vllm_ascend/torchair/torchair_model_runner.py | 173 +- vllm_ascend/torchair/torchair_sfa.py | 1330 ++++++++ vllm_ascend/torchair/torchair_worker.py | 44 +- vllm_ascend/torchair/utils.py | 35 +- vllm_ascend/utils.py | 188 +- vllm_ascend/worker/block_table.py | 312 ++ vllm_ascend/worker/eagle_proposer_v1.py | 398 --- vllm_ascend/worker/model_runner_v1.py | 2816 +++++++++++------ vllm_ascend/worker/mtp_proposer_v1.py | 439 --- vllm_ascend/worker/npu_input_batch.py | 100 +- vllm_ascend/worker/worker_v1.py | 124 +- 278 files changed, 28130 insertions(+), 11708 deletions(-) create mode 100644 Dockerfile.310p create mode 100644 Dockerfile.310p.openEuler create mode 100644 Dockerfile.a3 create mode 100644 Dockerfile.a3.openEuler create mode 100644 Dockerfile.openEuler create mode 100644 docs/source/developer_guide/evaluation/accuracy_report/DeepSeek-V2-Lite.md create mode 100644 docs/source/developer_guide/evaluation/accuracy_report/Qwen2.5-VL-7B-Instruct.md create mode 100644 docs/source/developer_guide/evaluation/accuracy_report/Qwen3-30B-A3B.md create mode 100644 docs/source/developer_guide/evaluation/accuracy_report/Qwen3-8B-Base.md create mode 100644 docs/source/tutorials/multi_node_pd_disaggregation.md create mode 100644 docs/source/tutorials/multi_node_qwen3vl.md create mode 100644 docs/source/tutorials/multi_node_ray.md create mode 100644 docs/source/tutorials/multi_npu_qwen3_next.md create mode 100644 docs/source/user_guide/feature_guide/eplb_swift_balancer.md create mode 100644 docs/source/user_guide/feature_guide/images/eplb_img.png create mode 100644 examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md create mode 100644 
examples/offline_weight_load.py create mode 100644 tests/e2e/multicard/test_weight_loader.py delete mode 100644 tests/e2e/singlecard/ops/test_moe_comm.py create mode 100644 tests/e2e/singlecard/test_multistream_overlap_shared_expert.py create mode 100644 tests/e2e/vllm_interface/singlecard/test_sampler.py create mode 100644 tests/e2e/vllm_interface/vllm_test.cfg create mode 100644 tests/ut/compilation/test_acl_graph.py rename vllm_ascend/lora/punica_wrapper/__init__.py => tests/ut/distributed/test_determin_expert_map_all.py (100%) delete mode 100644 tests/ut/distributed/test_distributed_tensor_parallel.py create mode 100644 tests/ut/eplb/adaptor/test_abstract_adaptor.py create mode 100644 tests/ut/eplb/core/policy/test_policy_abstract.py create mode 100644 tests/ut/eplb/core/policy/test_policy_dynamic_ep.py create mode 100644 tests/ut/eplb/core/policy/test_policy_dynamic_ep_v2.py create mode 100644 tests/ut/eplb/core/policy/test_policy_factor.py create mode 100644 tests/ut/eplb/core/test_eplb_device_transfer_loader.py create mode 100644 tests/ut/eplb/core/test_eplb_utils.py create mode 100644 tests/ut/models/conftest.py create mode 100644 tests/ut/ops/test_comm_utils.py create mode 100644 tests/ut/ops/test_fused_moe_prepare_and_finalize.py create mode 100644 tests/ut/ops/test_moe_comm_method.py delete mode 100644 tests/ut/patch/worker/patch_common/test_patch_linear.py delete mode 100644 tests/ut/quantization/test_func_wrapper.py delete mode 100644 tests/ut/quantization/test_quantizer.py create mode 100644 tests/ut/quantization/test_utils.py create mode 100644 tests/ut/quantization/test_w8a8_dynamic.py create mode 100644 tests/ut/sample/logits_processor/test_builtin.py create mode 100644 tests/ut/torchair/test_torchair_attention.py create mode 100644 tests/ut/worker/test_model_runner_v1.py create mode 100644 vllm_ascend/attention/sfa_v1.py create mode 100644 vllm_ascend/distributed/cpu_offload_connector.py rename vllm_ascend/{ops/layers => 
distributed/cpu_offload_manager}/__init__.py (100%) create mode 100644 vllm_ascend/distributed/cpu_offload_manager/cpu_kv_cache_manager.py create mode 100644 vllm_ascend/distributed/cpu_offload_manager/metadata.py delete mode 100644 vllm_ascend/distributed/moe_comm_method.py rename vllm_ascend/{ops/moe_dispatcher => distributed/mooncake}/__init__.py (100%) create mode 100644 vllm_ascend/distributed/mooncake/config_data.py create mode 100644 vllm_ascend/distributed/mooncake/kv_transfer.py create mode 100644 vllm_ascend/distributed/mooncake/mooncake_engine.py create mode 100644 vllm_ascend/distributed/mooncake/mooncake_store.py create mode 100644 vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py delete mode 100644 vllm_ascend/distributed/tensor_parallel.py create mode 100644 vllm_ascend/eplb/__init__.py create mode 100644 vllm_ascend/eplb/adaptor/__init__.py create mode 100644 vllm_ascend/eplb/adaptor/abstract_adaptor.py create mode 100644 vllm_ascend/eplb/adaptor/vllm_adaptor.py create mode 100644 vllm_ascend/eplb/core/__init__.py create mode 100644 vllm_ascend/eplb/core/eplb_device_transfer_loader.py create mode 100644 vllm_ascend/eplb/core/eplb_utils.py create mode 100644 vllm_ascend/eplb/core/eplb_worker.py create mode 100644 vllm_ascend/eplb/core/policy/__init__.py create mode 100644 vllm_ascend/eplb/core/policy/policy_abstract.py create mode 100644 vllm_ascend/eplb/core/policy/policy_dynamic_ep.py create mode 100644 vllm_ascend/eplb/core/policy/policy_dynamic_ep_v2.py create mode 100644 vllm_ascend/eplb/core/policy/policy_factory.py create mode 100644 vllm_ascend/eplb/core/policy/policy_flashlb.py create mode 100644 vllm_ascend/eplb/core/policy/policy_random.py create mode 100644 vllm_ascend/eplb/eplb_updator.py create mode 100644 vllm_ascend/eplb/utils.py rename vllm_ascend/lora/{punica_wrapper => }/lora_ops.py (78%) rename vllm_ascend/lora/{punica_wrapper => }/punica_npu.py (94%) create mode 100644 vllm_ascend/lora/utils.py delete mode 100644 
vllm_ascend/models/deepseek_dbo.py create mode 100644 vllm_ascend/models/layers/__init__.py create mode 100644 vllm_ascend/models/layers/mla.py create mode 100644 vllm_ascend/models/layers/sfa.py delete mode 100644 vllm_ascend/models/pangu_moe.py delete mode 100644 vllm_ascend/models/qwen3.py create mode 100644 vllm_ascend/models/qwen3_next.py create mode 100644 vllm_ascend/ops/casual_conv1d.py create mode 100644 vllm_ascend/ops/fla.py create mode 100644 vllm_ascend/ops/linear_op.py create mode 100644 vllm_ascend/ops/moe/__init__.py rename vllm_ascend/ops/{ => moe}/comm_utils.py (55%) rename vllm_ascend/ops/{layers => moe}/experts_selector.py (100%) create mode 100644 vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py create mode 100644 vllm_ascend/ops/moe/moe_comm_method.py rename vllm_ascend/ops/{layers => moe}/moe_mlp.py (51%) rename vllm_ascend/ops/{moe_dispatcher => moe}/token_dispatcher.py (75%) create mode 100644 vllm_ascend/ops/register_custom_ops.py create mode 100644 vllm_ascend/ops/sigmoid_gating.py create mode 100644 vllm_ascend/patch/platform/patch_common/patch_config.py create mode 100644 vllm_ascend/patch/platform/patch_common/patch_mamba_config.py create mode 100644 vllm_ascend/patch/platform/patch_common/patch_multimodal_merge.py create mode 100644 vllm_ascend/patch/platform/patch_common/patch_transformers_utils.py create mode 100644 vllm_ascend/patch/worker/patch_common/patch_attention_layer.py create mode 100644 vllm_ascend/patch/worker/patch_common/patch_attention_selector.py create mode 100644 vllm_ascend/patch/worker/patch_common/patch_attentionspec.py delete mode 100644 vllm_ascend/patch/worker/patch_common/patch_linear.py delete mode 100644 vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py create mode 100644 vllm_ascend/patch/worker/patch_common/patch_triton.py create mode 100644 vllm_ascend/patch/worker/patch_common/patch_weight_loader.py delete mode 100644 vllm_ascend/quantization/func_wrapper.py delete mode 100644 
vllm_ascend/quantization/quantizer.py create mode 100644 vllm_ascend/quantization/utils.py create mode 100644 vllm_ascend/sample/logits_processor/__init__.py create mode 100644 vllm_ascend/sample/logits_processor/builtin.py create mode 100644 vllm_ascend/spec_decode/__init__.py create mode 100644 vllm_ascend/spec_decode/eagle_proposer.py create mode 100644 vllm_ascend/spec_decode/interface.py create mode 100644 vllm_ascend/spec_decode/mtp_proposer.py create mode 100644 vllm_ascend/spec_decode/ngram_proposer.py rename vllm_ascend/{ => torchair}/ops/sequence_parallel.py (100%) create mode 100644 vllm_ascend/torchair/ops/shared_weight_layer.py rename vllm_ascend/{distributed/communication_op.py => torchair/ops/torchair_activation.py} (52%) create mode 100644 vllm_ascend/torchair/ops/torchair_layernorm.py delete mode 100644 vllm_ascend/torchair/quantization/torchair_quantizer.py create mode 100644 vllm_ascend/torchair/torchair_sfa.py create mode 100644 vllm_ascend/worker/block_table.py delete mode 100644 vllm_ascend/worker/eagle_proposer_v1.py delete mode 100644 vllm_ascend/worker/mtp_proposer_v1.py diff --git a/Dockerfile b/Dockerfile index cabe38c..1d0b73b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. # -FROM git.modelhub.org.cn:9443/enginex-ascend/cann:8.2.rc1-910b-ubuntu22.04-py3.11 +FROM quay.io/ascend/cann:8.2.rc1-910b-ubuntu22.04-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 @@ -37,7 +37,7 @@ RUN pip config set global.index-url ${PIP_INDEX_URL} # Install vLLM ARG VLLM_REPO=https://github.com/vllm-project/vllm.git -ARG VLLM_TAG=v0.10.1.1 +ARG VLLM_TAG=v0.11.0rc3 RUN git clone --depth 1 $VLLM_REPO --branch $VLLM_TAG /vllm-workspace/vllm # In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. 
RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -v -e /vllm-workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \ diff --git a/Dockerfile.310p b/Dockerfile.310p new file mode 100644 index 0000000..f5ec94f --- /dev/null +++ b/Dockerfile.310p @@ -0,0 +1,61 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +FROM quay.io/ascend/cann:8.2.rc1-310p-ubuntu22.04-py3.11 + +ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" +ARG COMPILE_CUSTOM_KERNELS=1 + +# Define environments +ENV DEBIAN_FRONTEND=noninteractive +ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} + +RUN apt-get update -y && \ + apt-get install -y python3-pip git vim wget net-tools gcc g++ cmake libnuma-dev && \ + rm -rf /var/cache/apt/* && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +COPY . /vllm-workspace/vllm-ascend/ + +RUN pip config set global.index-url ${PIP_INDEX_URL} + +# Install vLLM +ARG VLLM_REPO=https://github.com/vllm-project/vllm.git +ARG VLLM_TAG=v0.11.0rc3 +RUN git clone --depth 1 $VLLM_REPO --branch $VLLM_TAG /vllm-workspace/vllm +# In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. 
+RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -v -e /vllm-workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip uninstall -y triton && \ + python3 -m pip cache purge + +# Install vllm-ascend +# Append `libascend_hal.so` path (devlib) to LD_LIBRARY_PATH +RUN export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + export SOC_VERSION=ASCEND310P3 && \ + python3 -m pip install -v -e /vllm-workspace/vllm-ascend/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip cache purge + +# Install modelscope (for fast download) and ray (for multinode) +RUN python3 -m pip install modelscope 'ray>=2.47.1' 'protobuf>3.20.0' && \ + python3 -m pip cache purge + +CMD ["/bin/bash"] diff --git a/Dockerfile.310p.openEuler b/Dockerfile.310p.openEuler new file mode 100644 index 0000000..3e9a2da --- /dev/null +++ b/Dockerfile.310p.openEuler @@ -0,0 +1,59 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +FROM quay.io/ascend/cann:8.2.rc1-310p-openeuler24.03-py3.11 + +ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" +ARG COMPILE_CUSTOM_KERNELS=1 + +ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} + +RUN yum update -y && \ + yum install -y python3-pip git vim wget net-tools gcc gcc-c++ make cmake numactl-devel && \ + rm -rf /var/cache/yum + +RUN pip config set global.index-url ${PIP_INDEX_URL} + +WORKDIR /workspace + +COPY . /vllm-workspace/vllm-ascend/ + +# Install vLLM +ARG VLLM_REPO=https://github.com/vllm-project/vllm.git +ARG VLLM_TAG=v0.11.0rc3 + +RUN git clone --depth 1 $VLLM_REPO --branch $VLLM_TAG /vllm-workspace/vllm +# In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. +RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /vllm-workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip uninstall -y triton && \ + python3 -m pip cache purge + +# Install vllm-ascend +RUN export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/usr/include/c++/12:/usr/include/c++/12/`uname -i`-openEuler-linux && \ + export SOC_VERSION=ASCEND310P3 && \ + python3 -m pip install -v -e /vllm-workspace/vllm-ascend/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip cache purge + +# Install modelscope (for fast download) and ray (for multinode) +RUN python3 -m pip install modelscope 'ray>=2.47.1' 'protobuf>3.20.0' && \ + python3 -m pip cache purge + +CMD ["/bin/bash"] diff --git a/Dockerfile.a3 b/Dockerfile.a3 new file mode 100644 index 0000000..de01698 --- /dev/null +++ b/Dockerfile.a3 @@ -0,0 +1,60 @@ +# +# Copyright (c) 2025 Huawei 
Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +FROM quay.io/ascend/cann:8.2.rc1-a3-ubuntu22.04-py3.11 + +ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" +ARG COMPILE_CUSTOM_KERNELS=1 + +# Define environments +ENV DEBIAN_FRONTEND=noninteractive +ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} + +RUN apt-get update -y && \ + apt-get install -y python3-pip git vim wget net-tools gcc g++ cmake libnuma-dev && \ + rm -rf /var/cache/apt/* && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +COPY . /vllm-workspace/vllm-ascend/ + +RUN pip config set global.index-url ${PIP_INDEX_URL} + +# Install vLLM +ARG VLLM_REPO=https://github.com/vllm-project/vllm.git +ARG VLLM_TAG=v0.11.0rc3 +RUN git clone --depth 1 $VLLM_REPO --branch $VLLM_TAG /vllm-workspace/vllm +# In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. 
+RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -v -e /vllm-workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip uninstall -y triton && \ + python3 -m pip cache purge + +# Install vllm-ascend +# Append `libascend_hal.so` path (devlib) to LD_LIBRARY_PATH +RUN export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + python3 -m pip install -v -e /vllm-workspace/vllm-ascend/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip cache purge + +# Install modelscope (for fast download) and ray (for multinode) +RUN python3 -m pip install modelscope 'ray>=2.47.1' 'protobuf>3.20.0' && \ + python3 -m pip cache purge + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/Dockerfile.a3.openEuler b/Dockerfile.a3.openEuler new file mode 100644 index 0000000..cec4ab6 --- /dev/null +++ b/Dockerfile.a3.openEuler @@ -0,0 +1,58 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +FROM quay.io/ascend/cann:8.2.rc1-a3-openeuler24.03-py3.11 + +ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" +ARG COMPILE_CUSTOM_KERNELS=1 + +ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} + +RUN yum update -y && \ + yum install -y python3-pip git vim wget net-tools gcc gcc-c++ make cmake numactl-devel && \ + rm -rf /var/cache/yum + +RUN pip config set global.index-url ${PIP_INDEX_URL} + +WORKDIR /workspace + +COPY . /vllm-workspace/vllm-ascend/ + +# Install vLLM +ARG VLLM_REPO=https://github.com/vllm-project/vllm.git +ARG VLLM_TAG=v0.11.0rc3 + +RUN git clone --depth 1 $VLLM_REPO --branch $VLLM_TAG /vllm-workspace/vllm +# In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. +RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /vllm-workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip uninstall -y triton && \ + python3 -m pip cache purge + +# Install vllm-ascend +RUN export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/usr/include/c++/12:/usr/include/c++/12/`uname -i`-openEuler-linux && \ + python3 -m pip install -v -e /vllm-workspace/vllm-ascend/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip cache purge + +# Install modelscope (for fast download) and ray (for multinode) +RUN python3 -m pip install modelscope 'ray>=2.47.1' 'protobuf>3.20.0' && \ + python3 -m pip cache purge + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/Dockerfile.openEuler b/Dockerfile.openEuler new file mode 100644 index 0000000..14b6cce --- /dev/null +++ b/Dockerfile.openEuler @@ -0,0 +1,58 @@ +# +# Copyright (c) 2025 Huawei 
Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +FROM quay.io/ascend/cann:8.2.rc1-910b-openeuler24.03-py3.11 + +ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" +ARG COMPILE_CUSTOM_KERNELS=1 + +ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} + +RUN yum update -y && \ + yum install -y python3-pip git vim wget net-tools gcc gcc-c++ make cmake numactl-devel && \ + rm -rf /var/cache/yum + +RUN pip config set global.index-url ${PIP_INDEX_URL} + +WORKDIR /workspace + +COPY . /vllm-workspace/vllm-ascend/ + +# Install vLLM +ARG VLLM_REPO=https://github.com/vllm-project/vllm.git +ARG VLLM_TAG=v0.11.0rc3 + +RUN git clone --depth 1 $VLLM_REPO --branch $VLLM_TAG /vllm-workspace/vllm +# In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. 
+RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /vllm-workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip uninstall -y triton && \ + python3 -m pip cache purge + +# Install vllm-ascend +RUN export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/usr/include/c++/12:/usr/include/c++/12/`uname -i`-openEuler-linux && \ + python3 -m pip install -v -e /vllm-workspace/vllm-ascend/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip cache purge + +# Install modelscope (for fast download) and ray (for multinode) +RUN python3 -m pip install modelscope 'ray>=2.47.1' 'protobuf>3.20.0' && \ + python3 -m pip cache purge + +CMD ["/bin/bash"] diff --git a/README.md b/README.md index 7bd91ef..918e5d7 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ## 镜像 -Latest RC Version: git.modelhub.org.cn:9443/enginex-ascend/vllm-ascend:v0.10.0rc1 +Latest RC Version: git.modelhub.org.cn:9443/enginex-ascend/vllm-ascend:v0.11.0rc0 ## 总览 @@ -78,4 +78,4 @@ curl -X POST http://localhost:10086/v1/chat/completions \ | Version | Release type | Doc | |------------|--------------|--------------------------------------| |v0.10.1rc1| 最新RC版本 |请查看[快速开始](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html)和[安装指南](https://vllm-ascend.readthedocs.io/en/latest/installation.html)了解更多| -|v0.9.1| 最新正式/稳定版本 |[快速开始](https://vllm-ascend.readthedocs.io/en/v0.9.1-dev/quick_start.html) and [安装指南](https://vllm-ascend.readthedocs.io/en/v0.9.1-dev/installation.html)了解更多| \ No newline at end of file +|v0.9.1| 最新正式/稳定版本 |[快速开始](https://vllm-ascend.readthedocs.io/en/v0.9.1-dev/quick_start.html) and 
[安装指南](https://vllm-ascend.readthedocs.io/en/v0.9.1-dev/installation.html)了解更多| diff --git a/README.md.ori b/README.md.ori index 72ed323..9c255b1 100644 --- a/README.md.ori +++ b/README.md.ori @@ -42,7 +42,7 @@ By using vLLM Ascend plugin, popular open-source models, including Transformer-l - OS: Linux - Software: * Python >= 3.9, < 3.12 - * CANN >= 8.2.rc1 + * CANN >= 8.2.rc1 (Ascend HDK version refers to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html)) * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724 * vLLM (the same version as vllm-ascend) @@ -52,7 +52,7 @@ Please use the following recommended versions to get started quickly: | Version | Release type | Doc | |------------|--------------|--------------------------------------| -|v0.10.1rc1|Latest release candidate|[QuickStart](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html) and [Installation](https://vllm-ascend.readthedocs.io/en/latest/installation.html) for more details| +|v0.11.0rc0|Latest release candidate|[QuickStart](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html) and [Installation](https://vllm-ascend.readthedocs.io/en/latest/installation.html) for more details| |v0.9.1|Latest stable version|[QuickStart](https://vllm-ascend.readthedocs.io/en/v0.9.1-dev/quick_start.html) and [Installation](https://vllm-ascend.readthedocs.io/en/v0.9.1-dev/installation.html) for more details| ## Contributing @@ -73,7 +73,7 @@ Below is maintained branches: | Branch | Status | Note | |------------|--------------|--------------------------------------| -| main | Maintained | CI commitment for vLLM main branch and vLLM 0.10.x branch | +| main | Maintained | CI commitment for vLLM main branch and vLLM v0.11.0 tag | | v0.7.1-dev | Unmaintained | Only doc fixed is allowed | | v0.7.3-dev | Maintained | CI commitment for vLLM 0.7.3 version, only bug fix is allowed and no new release tag any more. 
| | v0.9.1-dev | Maintained | CI commitment for vLLM 0.9.1 version | diff --git a/README.zh.md b/README.zh.md index d7f1310..bb7ddb9 100644 --- a/README.zh.md +++ b/README.zh.md @@ -43,7 +43,7 @@ vLLM 昇腾插件 (`vllm-ascend`) 是一个由社区维护的让vLLM在Ascend NP - 操作系统:Linux - 软件: * Python >= 3.9, < 3.12 - * CANN >= 8.2.rc1 + * CANN >= 8.2.rc1 (Ascend HDK 版本参考[这里](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html)) * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724 * vLLM (与vllm-ascend版本一致) @@ -53,7 +53,7 @@ vLLM 昇腾插件 (`vllm-ascend`) 是一个由社区维护的让vLLM在Ascend NP | Version | Release type | Doc | |------------|--------------|--------------------------------------| -|v0.10.1rc1| 最新RC版本 |请查看[快速开始](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html)和[安装指南](https://vllm-ascend.readthedocs.io/en/latest/installation.html)了解更多| +|v0.11.0rc0| 最新RC版本 |请查看[快速开始](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html)和[安装指南](https://vllm-ascend.readthedocs.io/en/latest/installation.html)了解更多| |v0.9.1| 最新正式/稳定版本 |[快速开始](https://vllm-ascend.readthedocs.io/en/v0.9.1-dev/quick_start.html) and [安装指南](https://vllm-ascend.readthedocs.io/en/v0.9.1-dev/installation.html)了解更多| ## 贡献 @@ -73,7 +73,7 @@ vllm-ascend有主干分支和开发分支。 | 分支 | 状态 | 备注 | |------------|------------|---------------------| -| main | Maintained | 基于vLLM main分支CI看护 | +| main | Maintained | 基于vLLM main分支和vLLM最新版本(v0.11.0)CI看护 | | v0.7.1-dev | Unmaintained | 只允许文档修复 | | v0.7.3-dev | Maintained | 基于vLLM v0.7.3版本CI看护, 只允许Bug修复,不会再发布新版本 | | v0.9.1-dev | Maintained | 基于vLLM v0.9.1版本CI看护 | diff --git a/benchmarks/ops/ben_vocabparallelembedding.py b/benchmarks/ops/ben_vocabparallelembedding.py index b3ef7ec..5590c73 100644 --- a/benchmarks/ops/ben_vocabparallelembedding.py +++ b/benchmarks/ops/ben_vocabparallelembedding.py @@ -112,7 +112,7 @@ def test_get_masked_input_and_mask( # Define custom function def custom_fn(): - return torch.ops._C.get_masked_input_and_mask( + return 
torch.ops._C_ascend.get_masked_input_and_mask( input_tensor, test_case["org_start"], test_case["org_end"], diff --git a/benchmarks/scripts/run-performance-benchmarks.sh b/benchmarks/scripts/run-performance-benchmarks.sh index b604fe9..befdf69 100644 --- a/benchmarks/scripts/run-performance-benchmarks.sh +++ b/benchmarks/scripts/run-performance-benchmarks.sh @@ -78,7 +78,9 @@ kill_npu_processes() { ps -aux lsof -t -i:8000 | xargs -r kill -9 pgrep python3 | xargs -r kill -9 - + # vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445 + pgrep VLLM | xargs -r kill -9 + sleep 4 rm -rf ~/.config/vllm diff --git a/benchmarks/tests/serving-tests.json b/benchmarks/tests/serving-tests.json index 6398710..c2be9eb 100644 --- a/benchmarks/tests/serving-tests.json +++ b/benchmarks/tests/serving-tests.json @@ -23,7 +23,8 @@ "hf_split": "train", "endpoint": "/v1/chat/completions", "dataset_path": "lmarena-ai/vision-arena-bench-v0.1", - "num_prompts": 200 + "num_prompts": 200, + "no_stream": "" } }, { diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 375ef59..5dd6988 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include "acl/acl.h" #include "ops.h" #include "utils.h" @@ -142,7 +141,7 @@ std::tuple get_masked_input_and_mask( TP2, rank 1: |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >| corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 | - index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... 
| 543 | Parameters: org_vocab_start_index //base embeddings start org_vocab_end_index //base embeddings end @@ -165,22 +164,22 @@ std::tuple get_masked_input_and_mask( // Create output tensors at::Tensor masked_input = at::empty_like(input); at::Tensor mask = at::empty_like(input).to(at::kBool); - + // Get data pointers void *input_ptr = input.data_ptr(); void *masked_input_ptr = masked_input.data_ptr(); void *mask_ptr = mask.data_ptr(); - + // Get current stream aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); - + // Get scalar type at::ScalarType scalar_type = input.scalar_type(); - + // Create and configure OpCommand at_npu::native::OpCommand cmd; cmd.Name("get_masked_input_and_mask"); - cmd.SetCustomHandler([scalar_type, size, stream, + cmd.SetCustomHandler([scalar_type, size, stream, input_ptr, masked_input_ptr, mask_ptr, org_vocab_start_index, org_vocab_end_index, num_org_vocab_padding, added_vocab_start_index, @@ -194,7 +193,7 @@ std::tuple get_masked_input_and_mask( get_masked_input_and_mask_impl( stream, input_ptr, - masked_input_ptr, + masked_input_ptr, mask_ptr, org_vocab_start_index, org_vocab_end_index, @@ -204,7 +203,7 @@ std::tuple get_masked_input_and_mask( size, loop_cnt, aiv_num); - + return 0; }); cmd.Run(); @@ -321,8 +320,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); at_npu::native::OpCommand cmd; cmd.Name("sgmv_shrink"); - cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, - seq_len_ptr, seq_len_size, y_ptr, + cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, + seq_len_ptr, seq_len_size, y_ptr, batch_size, input_hidden_token, lora_rank, scale_f]() -> int { auto dtype = get_dtype_from_torch(scalar_type); int device_id = 0; @@ -331,7 +330,7 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at int 
num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num; TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0"); sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, - y_ptr, batch_size, + y_ptr, batch_size, num_tokens_per_core, input_hidden_token, lora_rank, scale_f); return 0; }); @@ -368,7 +367,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); at_npu::native::OpCommand cmd; cmd.Name("sgmv_expand"); - cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, + cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int { auto dtype = get_dtype_from_torch(scalar_type); int device_id = 0; @@ -376,7 +375,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS); int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num; TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0"); - sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, + sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim); return 0; }); @@ -385,7 +384,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic } } // namespace vllm_ascend -TORCH_LIBRARY_EXPAND(_C, ops) +TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) { // vLLM-Ascend custom ops 
ops.def("weak_ref_tensor(Tensor input) -> Tensor"); @@ -424,5 +423,3 @@ TORCH_LIBRARY_EXPAND(_C, ops) " int slice_offset, int slice_size) -> Tensor"); ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand); } - -REGISTER_EXTENSION(_C) diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index d69254b..4101ee7 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -40,7 +40,7 @@ std::tuple rotary_embedding_meta( at::Tensor &positions, at::Tensor &query, at::Tensor &key, - int64_t head_size, + int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox) { auto num_tokens = positions.sym_numel(); @@ -86,9 +86,9 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_ } // namespace vllm_ascend namespace { - // Register the meta implementations of the custom kernels for symbolic tracing, this will also + // Register the meta implementations of the custom kernels for symbolic tracing, this will also // the custom kernel been captured into aclgraph - TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) { + TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { // Rotary embedding meta implementation ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta); // Masked input and mask meta implementation @@ -99,4 +99,4 @@ namespace { ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta); } -} \ No newline at end of file +} diff --git a/docs/source/community/versioning_policy.md b/docs/source/community/versioning_policy.md index a245dda..8465570 100644 --- a/docs/source/community/versioning_policy.md +++ b/docs/source/community/versioning_policy.md @@ -22,6 +22,8 @@ Following is the Release Compatibility Matrix for vLLM Ascend Plugin: | vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu | MindIE Turbo | |-------------|--------------|------------------|-------------|--------------------|--------------| +| v0.11.0rc0 | v0.11.0rc3 | >= 3.9, < 3.12 | 8.2.RC1 | 2.7.1 / 
2.7.1.dev20250724 | | +| v0.10.2rc1 | v0.10.2 | >= 3.9, < 3.12 | 8.2.RC1 | 2.7.1 / 2.7.1.dev20250724 | | | v0.10.1rc1 | v0.10.1/v0.10.1.1 | >= 3.9, < 3.12 | 8.2.RC1 | 2.7.1 / 2.7.1.dev20250724 | | | v0.10.0rc1 | v0.10.0 | >= 3.9, < 3.12 | 8.2.RC1 | 2.7.1 / 2.7.1.dev20250724 | | | v0.9.2rc1 | v0.9.2 | >= 3.9, < 3.12 | 8.1.RC1 | 2.5.1 / 2.5.1.post1.dev20250619 | | @@ -42,6 +44,8 @@ Following is the Release Compatibility Matrix for vLLM Ascend Plugin: | Date | Event | |------------|-------------------------------------------| +| 2025.09.30 | Release candidates, v0.11.0rc0 | +| 2025.09.16 | Release candidates, v0.10.2rc1 | | 2025.09.04 | Release candidates, v0.10.1rc1 | | 2025.09.03 | v0.9.1 Final release | | 2025.08.22 | Release candidates, v0.9.1rc3 | diff --git a/docs/source/conf.py b/docs/source/conf.py index 82d7a28..d864a3b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -65,19 +65,19 @@ myst_substitutions = { # the branch of vllm, used in vllm clone # - main branch: 'main' # - vX.Y.Z branch: 'vX.Y.Z' - 'vllm_version': 'v0.10.1.1', + 'vllm_version': 'v0.11.0rc3', # the branch of vllm-ascend, used in vllm-ascend clone and image tag # - main branch: 'main' # - vX.Y.Z branch: latest vllm-ascend release tag - 'vllm_ascend_version': 'v0.10.1rc1', + 'vllm_ascend_version': 'v0.11.0rc0', # the newest release version of vllm-ascend and matched vLLM, used in pip install. # This value should be updated when cut down release. - 'pip_vllm_ascend_version': "0.10.1rc1", - 'pip_vllm_version': "0.10.1.1", + 'pip_vllm_ascend_version': "0.11.0rc0", + 'pip_vllm_version': "0.11.0", # CANN image tag 'cann_image_tag': "8.2.rc1-910b-ubuntu22.04-py3.11", # vllm version in ci - 'ci_vllm_version': 'v0.10.1.1', + 'ci_vllm_version': 'v0.11.0rc3', } # Add any paths that contain templates here, relative to this directory. 
diff --git a/docs/source/developer_guide/evaluation/accuracy_report/DeepSeek-V2-Lite.md b/docs/source/developer_guide/evaluation/accuracy_report/DeepSeek-V2-Lite.md new file mode 100644 index 0000000..68d4369 --- /dev/null +++ b/docs/source/developer_guide/evaluation/accuracy_report/DeepSeek-V2-Lite.md @@ -0,0 +1,20 @@ +# deepseek-ai/DeepSeek-V2-Lite + +- **vLLM Version**: vLLM: 0.10.1.1 ([1da94e6](https://github.com/vllm-project/vllm/commit/1da94e6)), **vLLM Ascend Version**: v0.10.1rc1 ([7e16b4a](https://github.com/vllm-project/vllm-ascend/commit/7e16b4a)) +- **Software Environment**: **CANN**: 8.2.RC1, **PyTorch**: 2.7.1, **torch-npu**: 2.7.1.dev20250724 +- **Hardware Environment**: Atlas A2 Series +- **Parallel mode**: TP2 +- **Execution mode**: ACLGraph + +**Command**: + +```bash +export MODEL_ARGS='pretrained=deepseek-ai/DeepSeek-V2-Lite,tensor_parallel_size=2,dtype=auto,trust_remote_code=True,max_model_len=4096,enforce_eager=True' +lm_eval --model vllm --model_args $MODEL_ARGS --tasks gsm8k \ + --batch_size auto +``` + +| Task | Metric | Value | Stderr | +|-----------------------|-------------|----------:|-------:| +| gsm8k | exact_match,strict-match | ✅0.3813 | ± 0.0134 | +| gsm8k | exact_match,flexible-extract | ✅0.3836 | ± 0.0134 | diff --git a/docs/source/developer_guide/evaluation/accuracy_report/Qwen2.5-VL-7B-Instruct.md b/docs/source/developer_guide/evaluation/accuracy_report/Qwen2.5-VL-7B-Instruct.md new file mode 100644 index 0000000..6ceff53 --- /dev/null +++ b/docs/source/developer_guide/evaluation/accuracy_report/Qwen2.5-VL-7B-Instruct.md @@ -0,0 +1,19 @@ +# Qwen/Qwen2.5-VL-7B-Instruct + +- **vLLM Version**: vLLM: 0.10.1.1 ([1da94e6](https://github.com/vllm-project/vllm/commit/1da94e6)), **vLLM Ascend Version**: v0.10.1rc1 ([7e16b4a](https://github.com/vllm-project/vllm-ascend/commit/7e16b4a)) +- **Software Environment**: **CANN**: 8.2.RC1, **PyTorch**: 2.7.1, **torch-npu**: 2.7.1.dev20250724 +- **Hardware Environment**: Atlas A2 Series +- 
**Parallel mode**: TP1 +- **Execution mode**: ACLGraph + +**Command**: + +```bash +export MODEL_ARGS='pretrained=Qwen/Qwen2.5-VL-7B-Instruct,tensor_parallel_size=1,dtype=auto,trust_remote_code=False,max_model_len=8192' +lm_eval --model vllm-vlm --model_args $MODEL_ARGS --tasks mmmu_val \ + --apply_chat_template True --fewshot_as_multiturn True --batch_size auto +``` + +| Task | Metric | Value | Stderr | +|-----------------------|-------------|----------:|-------:| +| mmmu_val | acc,none | ✅0.52 | ± 0.0162 | diff --git a/docs/source/developer_guide/evaluation/accuracy_report/Qwen3-30B-A3B.md b/docs/source/developer_guide/evaluation/accuracy_report/Qwen3-30B-A3B.md new file mode 100644 index 0000000..d170936 --- /dev/null +++ b/docs/source/developer_guide/evaluation/accuracy_report/Qwen3-30B-A3B.md @@ -0,0 +1,21 @@ +# Qwen/Qwen3-30B-A3B + +- **vLLM Version**: vLLM: 0.10.1.1 ([1da94e6](https://github.com/vllm-project/vllm/commit/1da94e6)), **vLLM Ascend Version**: v0.10.1rc1 ([7e16b4a](https://github.com/vllm-project/vllm-ascend/commit/7e16b4a)) +- **Software Environment**: **CANN**: 8.2.RC1, **PyTorch**: 2.7.1, **torch-npu**: 2.7.1.dev20250724 +- **Hardware Environment**: Atlas A2 Series +- **Parallel mode**: TP2 + EP +- **Execution mode**: ACLGraph + +**Command**: + +```bash +export MODEL_ARGS='pretrained=Qwen/Qwen3-30B-A3B,tensor_parallel_size=2,dtype=auto,trust_remote_code=False,max_model_len=4096,gpu_memory_utilization=0.6,enable_expert_parallel=True' +lm_eval --model vllm --model_args $MODEL_ARGS --tasks gsm8k,ceval-valid \ + --num_fewshot 5 --batch_size auto +``` + +| Task | Metric | Value | Stderr | +|-----------------------|-------------|----------:|-------:| +| gsm8k | exact_match,strict-match | ✅0.8923 | ± 0.0085 | +| gsm8k | exact_match,flexible-extract | ✅0.8506 | ± 0.0098 | +| ceval-valid | acc,none | ✅0.8358 | ± 0.0099 | diff --git a/docs/source/developer_guide/evaluation/accuracy_report/Qwen3-8B-Base.md 
b/docs/source/developer_guide/evaluation/accuracy_report/Qwen3-8B-Base.md new file mode 100644 index 0000000..0649ee6 --- /dev/null +++ b/docs/source/developer_guide/evaluation/accuracy_report/Qwen3-8B-Base.md @@ -0,0 +1,21 @@ +# Qwen/Qwen3-8B-Base + +- **vLLM Version**: vLLM: 0.10.1.1 ([1da94e6](https://github.com/vllm-project/vllm/commit/1da94e6)), **vLLM Ascend Version**: v0.10.1rc1 ([7e16b4a](https://github.com/vllm-project/vllm-ascend/commit/7e16b4a)) +- **Software Environment**: **CANN**: 8.2.RC1, **PyTorch**: 2.7.1, **torch-npu**: 2.7.1.dev20250724 +- **Hardware Environment**: Atlas A2 Series +- **Parallel mode**: TP1 +- **Execution mode**: ACLGraph + +**Command**: + +```bash +export MODEL_ARGS='pretrained=Qwen/Qwen3-8B-Base,tensor_parallel_size=1,dtype=auto,trust_remote_code=False,max_model_len=4096' +lm_eval --model vllm --model_args $MODEL_ARGS --tasks gsm8k,ceval-valid \ + --apply_chat_template True --fewshot_as_multiturn True --num_fewshot 5 --batch_size auto +``` + +| Task | Metric | Value | Stderr | +|-----------------------|-------------|----------:|-------:| +| gsm8k | exact_match,strict-match | ✅0.8271 | ± 0.0104 | +| gsm8k | exact_match,flexible-extract | ✅0.8294 | ± 0.0104 | +| ceval-valid | acc,none | ✅0.815 | ± 0.0103 | diff --git a/docs/source/developer_guide/evaluation/accuracy_report/index.md b/docs/source/developer_guide/evaluation/accuracy_report/index.md index 0ed0a18..59f7f23 100644 --- a/docs/source/developer_guide/evaluation/accuracy_report/index.md +++ b/docs/source/developer_guide/evaluation/accuracy_report/index.md @@ -3,4 +3,8 @@ :::{toctree} :caption: Accuracy Report :maxdepth: 1 +DeepSeek-V2-Lite +Qwen2.5-VL-7B-Instruct +Qwen3-30B-A3B +Qwen3-8B-Base ::: diff --git a/docs/source/developer_guide/modeling/adding_a_new_model.md b/docs/source/developer_guide/modeling/adding_a_new_model.md index 117f559..5762fde 100644 --- a/docs/source/developer_guide/modeling/adding_a_new_model.md +++ 
b/docs/source/developer_guide/modeling/adding_a_new_model.md @@ -61,7 +61,6 @@ from torch import nn from vllm.attention import Attention from vllm.config import VllmConfig from vllm.sequence import IntermediateTensors -from vllm.model_executor.sampling_metadata import SamplingMetadata class CustomAttention(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str): diff --git a/docs/source/faqs.md b/docs/source/faqs.md index c0a3f0d..ec7f339 100644 --- a/docs/source/faqs.md +++ b/docs/source/faqs.md @@ -3,7 +3,7 @@ ## Version Specific FAQs - [[v0.9.1] FAQ & Feedback](https://github.com/vllm-project/vllm-ascend/issues/2643) -- [[v0.10.1rc1] FAQ & Feedback](https://github.com/vllm-project/vllm-ascend/issues/2630) +- [[v0.11.0rc1] FAQ & Feedback](https://github.com/vllm-project/vllm-ascend/issues/3222) ## General FAQs @@ -196,3 +196,21 @@ export ATB_LLM_LCOC_ENABLE=0 ### 19. How to fix the error "ImportError: Please install vllm[audio] for audio support" for Qwen2.5-Omni model? The `Qwen2.5-Omni` model requires the `librosa` package to be installed, you need to install the `qwen-omni-utils` package to ensure all dependencies are met `pip install qwen-omni-utils`, this package will install `librosa` and its related dependencies, resolving the `ImportError: No module named 'librosa'` issue and ensuring audio processing functionality works correctly. + +### 20. How to troubleshoot and resolve size capture failures resulting from stream resource exhaustion, and what are the underlying causes? + +``` +error example in detail: +ERROR 09-26 10:48:07 [model_runner_v1.py:3029] ACLgraph sizes capture fail: RuntimeError: +ERROR 09-26 10:48:07 [model_runner_v1.py:3029] ACLgraph has insufficient available streams to capture the configured number of sizes.Please verify both the availability of adequate streams and the appropriateness of the configured size count. +``` + +Recommended mitigation strategies: +1. 
Manually configure the compilation_config parameter with a reduced size set: '{"cudagraph_capture_sizes":[size1, size2, size3, ...]}'. +2. Employ ACLgraph's full graph mode as an alternative to the piece-wise approach. + +Root cause analysis: +The current stream requirement calculation for size captures only accounts for measurable factors including: data parallel size, tensor parallel size, expert parallel configuration, piece graph count, multistream overlap shared expert settings, and HCCL communication mode (AIV/AICPU). However, numerous unquantifiable elements - such as operator characteristics and specific hardware features - consume additional streams outside of this calculation framework, resulting in stream resource exhaustion during size capture operations. + +### 21. Installing vllm-ascend will overwrite the existing torch-npu package? +Installing vllm-ascend will overwrite the existing torch-npu package. If you need to install a specific version of torch-npu, you can manually install the specified version of torch-npu after installing vllm-ascend. diff --git a/docs/source/installation.md b/docs/source/installation.md index b06777e..0d3b54d 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -11,6 +11,7 @@ This document describes how to install vllm-ascend manually. 
| Software | Supported version | Note | |---------------|----------------------------------|-------------------------------------------| + | Ascend HDK | Refer to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html) | Required for CANN | | CANN | >= 8.2.RC1 | Required for vllm-ascend and torch-npu | | torch-npu | >= 2.7.1.dev20250724 | Required for vllm-ascend, No need to install manually, it will be auto installed in below steps | | torch | >= 2.7.1 | Required for torch-npu and vllm | diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/user_guide/configuration/additional_config.po b/docs/source/locale/zh_CN/LC_MESSAGES/user_guide/configuration/additional_config.po index 54dacd6..b60df5a 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/user_guide/configuration/additional_config.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/user_guide/configuration/additional_config.po @@ -148,10 +148,6 @@ msgid "" " to be passed in." msgstr "在为MOE模型使用专家负载均衡时,需要传入专家映射路径。" -#: ../../user_guide/configuration/additional_config.md -msgid "`chunked_prefill_for_mla`" -msgstr "`chunked_prefill_for_mla`" - #: ../../user_guide/configuration/additional_config.md msgid "`False`" msgstr "`False`" @@ -199,8 +195,8 @@ msgid "" msgstr "是否将MLA的向量操作放到另一个流中。此选项仅对使用MLA的模型(例如,DeepSeek)有效。" #: ../../user_guide/configuration/additional_config.md -msgid "`enable_multistream_moe`" -msgstr "`enable_multistream_moe`" +msgid "`multistream_overlap_shared_expert`" +msgstr "`multistream_overlap_shared_expert`" #: ../../user_guide/configuration/additional_config.md msgid "" diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md index 3c9d38f..971e6e0 100644 --- a/docs/source/tutorials/index.md +++ b/docs/source/tutorials/index.md @@ -8,6 +8,7 @@ single_npu_multimodal single_npu_audio single_npu_qwen3_embedding single_npu_qwen3_quantization +multi_npu_qwen3_next multi_npu multi_npu_moge multi_npu_qwen3_moe @@ -15,4 +16,7 @@ 
multi_npu_quantization single_node_300i multi_node multi_node_kimi +multi_node_qwen3vl +multi_node_pd_disaggregation +multi_node_ray ::: diff --git a/docs/source/tutorials/multi_node_pd_disaggregation.md b/docs/source/tutorials/multi_node_pd_disaggregation.md new file mode 100644 index 0000000..e334973 --- /dev/null +++ b/docs/source/tutorials/multi_node_pd_disaggregation.md @@ -0,0 +1,244 @@ +# Prefill-Decode Disaggregation Verification (Qwen) + +## Getting Started + +vLLM-Ascend now supports prefill-decode (PD) disaggregation with EP (Expert Parallel) options. This guide provides step-by-step instructions to verify these features with constrained resources. + +Take the Qwen3-30B-A3B model as an example: use vllm-ascend v0.10.1rc1 (with vLLM v0.10.1.1) on 3 Atlas 800T A2 servers to deploy the "1P2D" architecture. Assume the IP of the prefiller server is 192.0.0.1, and the decoder servers are 192.0.0.2 (decoder 1) and 192.0.0.3 (decoder 2). On each server, use 2 NPUs to deploy one service instance. + +## Verify Multi-Node Communication Environment + +### Physical Layer Requirements + +- The physical machines must be located on the same WLAN, with network connectivity. +- All NPUs must be interconnected. Intra-node connectivity is via HCCS, and inter-node connectivity is via RDMA. + +### Verification Process + +1. Single Node Verification: + +Execute the following commands on each node in sequence. 
The results must all be `success` and the status must be `UP`: + +```bash +# Check the remote switch ports +for i in {0..7}; do hccn_tool -i $i -lldp -g | grep Ifname; done +# Get the link status of the Ethernet ports (UP or DOWN) +for i in {0..7}; do hccn_tool -i $i -link -g ; done +# Check the network health status +for i in {0..7}; do hccn_tool -i $i -net_health -g ; done +# View the network detected IP configuration +for i in {0..7}; do hccn_tool -i $i -netdetect -g ; done +# View gateway configuration +for i in {0..7}; do hccn_tool -i $i -gateway -g ; done +# View NPU network configuration +cat /etc/hccn.conf +``` + +2. Get NPU IP Addresses + +```bash +for i in {0..7}; do hccn_tool -i $i -ip -g;done +``` + +3. Cross-Node PING Test + +```bash +# Execute on the target node (replace 'x.x.x.x' with actual npu ip address) +for i in {0..7}; do hccn_tool -i $i -ping -g address x.x.x.x;done +``` + +## Generate Ranktable + +The rank table is a JSON file that specifies the mapping of Ascend NPU ranks to nodes. For more details please refer to the [vllm-ascend examples](https://github.com/vllm-project/vllm-ascend/blob/main/examples/disaggregated_prefill_v1/README.md). Execute the following commands for reference. + +```shell +cd vllm-ascend/examples/disaggregate_prefill_v1/ +bash gen_ranktable.sh --ips \ + --npus-per-node --network-card-name --prefill-device-cnt --decode-device-cnt \ + [--local-device-ids ,,...] +``` + +Assume that we use device 0,1 on the prefiller server node and device 6,7 on both of the decoder server nodes. Take the following commands as an example. (`--local-device-ids` is necessary if you specify certain NPU devices on the local server.) 
+ +```shell +# On the prefiller node +cd vllm-ascend/examples/disaggregate_prefill_v1/ +bash gen_ranktable.sh --ips 192.0.0.1 192.0.0.2 192.0.0.3 \ + --npus-per-node 2 --network-card-name eth0 --prefill-device-cnt 2 --decode-device-cnt 4 --local-device-ids 0,1 + +# On the decoder 1 +cd vllm-ascend/examples/disaggregate_prefill_v1/ +bash gen_ranktable.sh --ips 192.0.0.1 192.0.0.2 192.0.0.3 \ + --npus-per-node 2 --network-card-name eth0 --prefill-device-cnt 2 --decode-device-cnt 4 --local-device-ids 6,7 + +# On the decoder 2 +cd vllm-ascend/examples/disaggregate_prefill_v1/ +bash gen_ranktable.sh --ips 192.0.0.1 192.0.0.2 192.0.0.3 \ + --npus-per-node 2 --network-card-name eth0 --prefill-device-cnt 2 --decode-device-cnt 4 --local-device-ids 6,7 +``` + +The rank table will be generated at /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json + +|Parameter | Meaning | +| --- | --- | +| --ips | Each node's local IP (prefiller nodes should be placed in front of decoder nodes) | +| --npus-per-node | Each node's NPU chips | +| --network-card-name | The physical machines' NIC | +|--prefill-device-cnt | NPU chips used for prefill | +|--decode-device-cnt |NPU chips used for decode | +|--local-device-ids |Optional. No need if using all devices on the local node. | + +## Prefiller / Decoder Deployment + +We can run the following scripts to launch a server on the prefiller/decoder node respectively. 
+ +:::::{tab-set} + +::::{tab-item} Prefiller node + +```shell +export HCCL_IF_IP=192.0.0.1 # node ip +export GLOO_SOCKET_IFNAME="eth0" # network card name +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="/path/to/your/generated/ranktable.json" +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=10 +export VLLM_USE_V1=1 + +vllm serve /model/Qwen3-30B-A3B \ + --host 0.0.0.0 \ + --port 13700 \ + --tensor-parallel-size 2 \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name qwen3-moe \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --gpu-memory-utilization 0.9 \ + --enable-expert-parallel \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_producer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' \ + --enforce-eager +``` + +:::: + +::::{tab-item} Decoder node 1 + +```shell +export HCCL_IF_IP=192.0.0.2 # node ip +export GLOO_SOCKET_IFNAME="eth0" # network card name +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="/path/to/your/generated/ranktable.json" +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=10 +export VLLM_USE_V1=1 + +vllm serve /model/Qwen3-30B-A3B \ + --host 0.0.0.0 \ + --port 13700 \ + --no-enable-prefix-caching \ + --tensor-parallel-size 2 \ + --seed 1024 \ + --served-model-name qwen3-moe \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --gpu-memory-utilization 0.9 \ + --enable-expert-parallel \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + 
"kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' +``` + +:::: + +::::{tab-item} Decoder node 2 + +```shell +export HCCL_IF_IP=192.0.0.3 # node ip +export GLOO_SOCKET_IFNAME="eth0" # network card name +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="/path/to/your/generated/ranktable.json" +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=10 +export VLLM_USE_V1=1 + +vllm serve /model/Qwen3-30B-A3B \ + --host 0.0.0.0 \ + --port 13700 \ + --no-enable-prefix-caching \ + --tensor-parallel-size 2 \ + --seed 1024 \ + --served-model-name qwen3-moe \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --gpu-memory-utilization 0.9 \ + --enable-expert-parallel \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' +``` + +:::: + +::::: + +## Example proxy for Deployment + +Run a proxy server on the same node with prefiller service instance. 
You can get the proxy program in the repository's examples: [load\_balance\_proxy\_server\_example.py](https://github.com/vllm-project/vllm-ascend/blob/main/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py) + +```shell +python load_balance_proxy_server_example.py \ + --host 192.0.0.1 \ + --port 8080 \ + --prefiller-hosts 192.0.0.1 \ + --prefiller-port 13700 \ + --decoder-hosts 192.0.0.2 192.0.0.3 \ + --decoder-ports 13700 13700 +``` + +## Verification + +Check service health using the proxy server endpoint. + +```shell +curl http://192.0.0.1:8080/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "qwen3-moe", + "prompt": "Who are you?", + "max_tokens": 100, + "temperature": 0 + }' +``` diff --git a/docs/source/tutorials/multi_node_qwen3vl.md b/docs/source/tutorials/multi_node_qwen3vl.md new file mode 100644 index 0000000..40a4d2a --- /dev/null +++ b/docs/source/tutorials/multi_node_qwen3vl.md @@ -0,0 +1,156 @@ +# Multi-Node-DP (Qwen3-VL-235B-A22B) + +## Verify Multi-Node Communication Environment + +Refer to [multi_node.md](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_node.html#verification-process) + +## Run with docker +Assume you have two Atlas 800 A3 (64G*16) nodes (or 2 * A2), and want to deploy the `Qwen3-VL-235B-A22B-Instruct` model across multiple nodes. 
+ +```{code-block} bash + :substitutions: +# Update the vllm-ascend image +export IMAGE=quay.io/ascend/vllm-ascend:|vllm_ascend_version| +docker run --rm \ +--name vllm-ascend \ +--net=host \ +--device /dev/davinci0 \ +--device /dev/davinci1 \ +--device /dev/davinci2 \ +--device /dev/davinci3 \ +--device /dev/davinci4 \ +--device /dev/davinci5 \ +--device /dev/davinci6 \ +--device /dev/davinci7 \ +--device /dev/davinci8 \ +--device /dev/davinci9 \ +--device /dev/davinci10 \ +--device /dev/davinci11 \ +--device /dev/davinci12 \ +--device /dev/davinci13 \ +--device /dev/davinci14 \ +--device /dev/davinci15 \ +--device /dev/davinci_manager \ +--device /dev/devmm_svm \ +--device /dev/hisi_hdc \ +-v /usr/local/dcmi:/usr/local/dcmi \ +-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \ +-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ +-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ +-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ +-v /etc/ascend_install.info:/etc/ascend_install.info \ +-v /root/.cache:/root/.cache \ +-p 8000:8000 \ +-it $IMAGE bash +``` + +Run the following scripts on two nodes respectively + +:::{note} +Before launch the inference server, ensure the following environment variables are set for multi node communication +::: + +node0 + +```shell +#!/bin/sh +# this obtained through ifconfig +# nic_name is the network interface name corresponding to local_ip +nic_name="xxxx" +local_ip="xxxx" + +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +export HCCL_BUFFSIZE=1024 + +vllm serve Qwen/Qwen3-VL-235B-A22B-Instruct \ +--host 0.0.0.0 \ +--port 8000 \ +--data-parallel-size 2 \ +--api-server-count 2 \ +--data-parallel-size-local 1 \ +--data-parallel-address $local_ip \ +--data-parallel-rpc-port 13389 \ 
+--seed 1024 \ +--served-model-name qwen3vl \ +--tensor-parallel-size 8 \ +--enable-expert-parallel \ +--max-num-seqs 16 \ +--max-model-len 32768 \ +--max-num-batched-tokens 4096 \ +--trust-remote-code \ +--no-enable-prefix-caching \ +--gpu-memory-utilization 0.8 \ +``` + +node1 + +```shell +#!/bin/sh + +nic_name="xxxx" +local_ip="xxxx" +node0_ip="xxxx" + +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +export HCCL_BUFFSIZE=1024 + +vllm serve Qwen/Qwen3-VL-235B-A22B-Instruct \ +--host 0.0.0.0 \ +--port 8000 \ +--headless \ +--data-parallel-size 2 \ +--data-parallel-size-local 1 \ +--data-parallel-start-rank 1 \ +--data-parallel-address $node0_ip \ +--data-parallel-rpc-port 13389 \ +--seed 1024 \ +--tensor-parallel-size 8 \ +--served-model-name qwen3vl \ +--max-num-seqs 16 \ +--max-model-len 32768 \ +--max-num-batched-tokens 4096 \ +--enable-expert-parallel \ +--trust-remote-code \ +--no-enable-prefix-caching \ +--gpu-memory-utilization 0.8 \ +``` + +If the service starts successfully, the following information will be displayed on node0: + +```shell +INFO: Started server process [44610] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Started server process [44611] +INFO: Waiting for application startup. +INFO: Application startup complete. 
+``` + +Once your server is started, you can query the model with input prompts: + +```shell +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "qwen3vl", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": "https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png"}}, + {"type": "text", "text": "What is the text in the illustrate?"} + ]} + ] + }' +``` diff --git a/docs/source/tutorials/multi_node_ray.md b/docs/source/tutorials/multi_node_ray.md new file mode 100644 index 0000000..ad1a8d6 --- /dev/null +++ b/docs/source/tutorials/multi_node_ray.md @@ -0,0 +1,182 @@ +# Multi-Node-Ray (Qwen/Qwen3-235B-A22B) + +Multi-node inference is suitable for the scenarios that the model cannot be deployed on a single machine. In such cases, the model can be distributed using tensor parallelism or pipeline parallelism. The specific parallelism strategies will be covered in the following sections. To successfully deploy multi-node inference, the following three steps need to be completed: + +* **Verify Multi-Node Communication Environment** +* **Set Up and Start the Ray Cluster** +* **Start the Online Inference Service on multinode** + +## Verify Multi-Node Communication Environment + +### Physical Layer Requirements: + +* The physical machines must be located on the same LAN, with network connectivity. +* All NPUs are connected with optical modules, and the connection status must be normal. + +### Verification Process: + +Execute the following commands on each node in sequence. 
The results must all be `success` and the status must be `UP`: + +```bash + # Check the remote switch ports + for i in {0..7}; do hccn_tool -i $i -lldp -g | grep Ifname; done + # Get the link status of the Ethernet ports (UP or DOWN) + for i in {0..7}; do hccn_tool -i $i -link -g ; done + # Check the network health status + for i in {0..7}; do hccn_tool -i $i -net_health -g ; done + # View the network detected IP configuration + for i in {0..7}; do hccn_tool -i $i -netdetect -g ; done + # View gateway configuration + for i in {0..7}; do hccn_tool -i $i -gateway -g ; done + # View NPU network configuration + cat /etc/hccn.conf +``` + +### NPU Interconnect Verification: +#### 1. Get NPU IP Addresses + +```bash +for i in {0..7}; do hccn_tool -i $i -ip -g | grep ipaddr; done +``` + +#### 2. Cross-Node PING Test + +```bash +# Execute on the target node (replace with actual IP) +hccn_tool -i 0 -ping -g address 10.20.0.20 +``` + +## Set Up and Start the Ray Cluster +### Setting Up the Basic Container +To ensure a consistent execution environment across all nodes, including the model path and Python environment, it is recommended to use Docker images. + +For setting up a multi-node inference cluster with Ray, **containerized deployment** is the preferred approach. Containers should be started on both the master and worker nodes, with the `--net=host` option to enable proper network connectivity. 
+ +Below is the example container setup command, which should be executed on **all nodes** : + +```{code-block} bash + :substitutions: +# Update the vllm-ascend image +export IMAGE=quay.nju.edu.cn/ascend/vllm-ascend:|vllm_ascend_version| +export NAME=vllm-ascend + +# Run the container using the defined variables +# Note if you are running bridge network with docker, Please expose available ports for multiple nodes communication in advance +docker run --rm \ +--name $NAME \ +--net=host \ +--device /dev/davinci0 \ +--device /dev/davinci1 \ +--device /dev/davinci2 \ +--device /dev/davinci3 \ +--device /dev/davinci4 \ +--device /dev/davinci5 \ +--device /dev/davinci6 \ +--device /dev/davinci7 \ +--device /dev/davinci_manager \ +--device /dev/devmm_svm \ +--device /dev/hisi_hdc \ +-v /usr/local/dcmi:/usr/local/dcmi \ +-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \ +-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ +-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ +-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ +-v /etc/ascend_install.info:/etc/ascend_install.info \ +-v /path/to/shared/cache:/root/.cache \ # IMPORTANT: This must be a shared directory accessible by all nodes +-it $IMAGE bash +``` + +### Start Ray Cluster +After setting up the containers and installing vllm-ascend on each node, follow the steps below to start the Ray cluster and execute inference tasks. + +Choose one machine as the head node and the others as worker nodes. Before proceeding, use `ip addr` to check your `nic_name` (network interface name). + +Set the `ASCEND_RT_VISIBLE_DEVICES` environment variable to specify the NPU devices to use. For Ray versions above 2.1, also set the `RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES` variable to avoid device recognition issues. 
+ +Below are the commands for the head and worker nodes: + +**Head node**: + +:::{note} +When starting a Ray cluster for multi-node inference, the environment variables on each node must be set **before** starting the Ray cluster for them to take effect. +Updating the environment variables requires restarting the Ray cluster. +::: + +```shell +# Head node +export HCCL_IF_IP={local_ip} +export GLOO_SOCKET_IFNAME={nic_name} +export TP_SOCKET_IFNAME={nic_name} +export RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES=1 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +ray start --head +``` + +**Worker node**: + +:::{note} +When starting a Ray cluster for multi-node inference, the environment variables on each node must be set **before** starting the Ray cluster for them to take effect. Updating the environment variables requires restarting the Ray cluster. +::: + +```shell +# Worker node +export HCCL_IF_IP={local_ip} +export GLOO_SOCKET_IFNAME={nic_name} +export TP_SOCKET_IFNAME={nic_name} +export RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES=1 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +ray start --address='{head_node_ip}:6379' --node-ip-address={local_ip} +``` + +Once the cluster is started on multiple nodes, execute `ray status` and `ray list nodes` to verify the Ray cluster's status. You should see the correct number of nodes and NPUs listed. + +## Start the Online Inference Service on multinode scenario +In the container, you can use vLLM as if all NPUs were on a single node. vLLM will utilize NPU resources across all nodes in the Ray cluster. + +**You only need to run the vllm command on one node.** + +To set up parallelism, the common practice is to set the `tensor-parallel-size` to the number of NPUs per node, and the `pipeline-parallel-size` to the number of nodes. 
+ +For example, with 16 NPUs across 2 nodes (8 NPUs per node), set the tensor parallel size to 8 and the pipeline parallel size to 2: + +```shell +vllm serve Qwen/Qwen3-235B-A22B \ + --distributed-executor-backend ray \ + --pipeline-parallel-size 2 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --seed 1024 \ + --max-model-len 8192 \ + --max-num-seqs 25 \ + --served-model-name qwen \ + --trust-remote-code \ + --gpu-memory-utilization 0.9 +``` + +Alternatively, if you want to use only tensor parallelism, set the tensor parallel size to the total number of NPUs in the cluster. For example, with 16 NPUs across 2 nodes, set the tensor parallel size to 16: + +```shell +vllm serve Qwen/Qwen3-235B-A22B \ + --distributed-executor-backend ray \ + --tensor-parallel-size 16 \ + --enable-expert-parallel \ + --seed 1024 \ + --max-model-len 8192 \ + --max-num-seqs 25 \ + --served-model-name qwen \ + --trust-remote-code \ + --gpu-memory-utilization 0.9 +``` + +Once your server is started, you can query the model with input prompts: + +```bash +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "qwen", + "prompt": "tell me how to sleep well", + "max_tokens": 100, + "temperature": 0 + }' +``` diff --git a/docs/source/tutorials/multi_npu_qwen3_next.md b/docs/source/tutorials/multi_npu_qwen3_next.md new file mode 100644 index 0000000..4fa5861 --- /dev/null +++ b/docs/source/tutorials/multi_npu_qwen3_next.md @@ -0,0 +1,156 @@ +# Multi-NPU (Qwen3-Next) + +```{note} +The Qwen3 Next are using [Triton Ascend](https://gitee.com/ascend/triton-ascend) which is currently experimental. In future versions, there may be behavioral changes around stability, accuracy and performance improvement. 
+``` + +## Run vllm-ascend on Multi-NPU with Qwen3 Next + +Run docker container: + +```{code-block} bash + :substitutions: +# Update the vllm-ascend image +export IMAGE=quay.io/ascend/vllm-ascend:|vllm_ascend_version| +docker run --rm \ +--name vllm-ascend-qwen3 \ +--device /dev/davinci0 \ +--device /dev/davinci1 \ +--device /dev/davinci2 \ +--device /dev/davinci3 \ +--device /dev/davinci_manager \ +--device /dev/devmm_svm \ +--device /dev/hisi_hdc \ +-v /usr/local/dcmi:/usr/local/dcmi \ +-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ +-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ +-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ +-v /etc/ascend_install.info:/etc/ascend_install.info \ +-v /root/.cache:/root/.cache \ +-p 8000:8000 \ +-it $IMAGE bash +``` + +Setup environment variables: + +```bash +# Load model from ModelScope to speed up download +export VLLM_USE_MODELSCOPE=True +``` + +### Install Triton Ascend + +:::::{tab-set} +::::{tab-item} Linux (aarch64) + +The [Triton Ascend](https://gitee.com/ascend/triton-ascend) is required when you run Qwen3 Next, please follow the instructions below to install it and its dependency. + +Install the Ascend BiSheng toolkit: + +```bash +wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/Ascend-BiSheng-toolkit_aarch64.run +chmod a+x Ascend-BiSheng-toolkit_aarch64.run +./Ascend-BiSheng-toolkit_aarch64.run --install +source /usr/local/Ascend/8.3.RC1/bisheng_toolkit/set_env.sh +``` + +Install Triton Ascend: + +```bash +wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl +pip install triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl +``` + +:::: + +::::{tab-item} Linux (x86_64) + +Coming soon ... 
+ +:::: +::::: + +### Inference on Multi-NPU + +Please make sure you already executed the command: + +```bash +source /usr/local/Ascend/8.3.RC1/bisheng_toolkit/set_env.sh +``` + +:::::{tab-set} +::::{tab-item} Online Inference + +Run the following script to start the vLLM server on Multi-NPU: + +For an Atlas A2 with 64GB of NPU card memory, tensor-parallel-size should be at least 4, and for 32GB of memory, tensor-parallel-size should be at least 8. + +```bash +vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --tensor-parallel-size 4 --max-model-len 4096 --gpu-memory-utilization 0.7 --enforce-eager +``` + +Once your server is started, you can query the model with input prompts + +```bash +curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "model": "Qwen/Qwen3-Next-80B-A3B-Instruct", + "messages": [ + {"role": "user", "content": "Who are you?"} + ], + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_tokens": 32 +}' +``` + +:::: + +::::{tab-item} Offline Inference + +Run the following script to execute offline inference on multi-NPU: + +```python +import gc +import torch + +from vllm import LLM, SamplingParams +from vllm.distributed.parallel_state import (destroy_distributed_environment, + destroy_model_parallel) + +def clean_up(): + destroy_model_parallel() + destroy_distributed_environment() + gc.collect() + torch.npu.empty_cache() + +if __name__ == '__main__': + prompts = [ + "Who are you?", + ] + sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=40, max_tokens=32) + llm = LLM(model="Qwen/Qwen3-Next-80B-A3B-Instruct", + tensor_parallel_size=4, + enforce_eager=True, + distributed_executor_backend="mp", + gpu_memory_utilization=0.7, + max_model_len=4096) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + del llm + clean_up() +``` + +If you run 
this script successfully, you can see the info shown below: + +```bash +Prompt: 'Who are you?', Generated text: ' What do you know about me?\n\nHello! I am Qwen, a large-scale language model independently developed by the Tongyi Lab under Alibaba Group. I am' +``` + +:::: +::::: diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index c67f340..e709b3a 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -30,11 +30,18 @@ The following table lists the additional configuration options available in vLLM | `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler | | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. | | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. | -| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. | | `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. | | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. | | `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. | | `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. | +| `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. | +| `multistream_overlap_shared_expert`| bool | `False` | Whether to enable multistream shared expert. This option only takes effects on moe models with shared experts. 
| +|`dynamic_eplb`| bool | `False` | Whether to enable dynamic eplb | +|`num_iterations_eplb_update`| int | `400` | Forward iterations when eplb would begin | +|`gate_eplb`| bool | `False` | Whether to enable eplb only once. | +|`num_wait_worker_iterations`| int | `30` | The forward iterations when eplb worker will finish cpu task. In our test default value 30 would cover most cases. | +|`expert_map_record_path`| str | `None` | When dynamic eplb is completed, save the current expert load heatmap to the specified path. | +|`init_redundancy_expert`| int | `0` |Specify redundant experts during initialization.| The details of each config option are as follows: @@ -45,8 +52,8 @@ The details of each config option are as follows: | `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE are supported to use torchair graph mode | | `mode` | str | `None` | When using reduce-overhead mode for torchair, mode needs to be set | | `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream. This option only takes effects on models using MLA (e.g., DeepSeek). | -| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert. This option only takes effects on DeepSeek moe models. | | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization | +| `enable_frozen_parameter` | bool | `True` | Whether to fix the memory address of weights during inference to reduce the input address refresh time during graph execution. 
| | `use_cached_graph` | bool | `False` | Whether to use cached graph | | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache | | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty | @@ -57,6 +64,10 @@ The details of each config option are as follows: | Name | Type | Default | Description | | ---- | ---- | ------- | ----------- | | `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine| +| `enable_pd_transfer` | bool | `False` | Whether to enable pd transfer. When using it, decode is started only when prefill of all requests is done. This option only takes effects on offline inference. | +| `decode_max_num_seqs` | int | `0` | Whether to change max_num_seqs of decode phase when enable pd transfer. This option only takes effects when enable_pd_transfer is True. | +| `max_long_partial_prefills` | Union[int, float] | `float('inf')` | the maximum number of prompts longer than long_prefill_token_threshold that will be prefilled concurrently. | +| `long_prefill_token_threshold` | Union[int, float] | `float('inf')` | a request is considered long if the prompt is longer than this number of tokens. | ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well. 
@@ -71,13 +82,15 @@ An example of additional configuration is as follows: "use_cached_graph": True, "graph_batch_sizes": [1, 2, 4, 8], "graph_batch_sizes_init": False, - "enable_multistream_moe": False, "enable_kv_nz": False }, "ascend_scheduler_config": { "enabled": True, "enable_chunked_prefill": True, + "max_long_partial_prefills": 1, + "long_prefill_token_threshold": 4096, }, + "multistream_overlap_shared_expert": True, "refresh": False, } ``` diff --git a/docs/source/user_guide/feature_guide/eplb_swift_balancer.md b/docs/source/user_guide/feature_guide/eplb_swift_balancer.md new file mode 100644 index 0000000..707b263 --- /dev/null +++ b/docs/source/user_guide/feature_guide/eplb_swift_balancer.md @@ -0,0 +1,94 @@ +# Expert Load Balance (EPLB) + +## Overview + +Expert balancing for MoE models in LLM serving is essential for optimal performance. Dynamically changing experts during inference can negatively impact TTFT (Time To First Token) and TPOT (Tokens Per Output Token) due to stop-the-world operations. SwiftBalancer enables asynchronous expert load balancing with zero-overhead expert movement, ensuring seamless service continuity. + +## EPLB Effects + +- Reduced Latency: Dynamically balances expert loads to minimize TTFT and TPOT by distributing workloads evenly across experts. +- Enhanced Throughput: Optimizes GPU utilization, increasing token generation speed under high-concurrency scenarios. +- Zero-Overhead Movement: Expert redistribution occurs asynchronously without interrupting ongoing inference requests. +- Adaptive Scaling: Automatically adjusts to workload fluctuations while maintaining stable performance. +- Fault Tolerance: Redundant expert placement ensures system resilience during hardware failures. + +## How to Use EPLB + +### Dynamic EPLB + +Enable dynamic balancing with auto-tuned parameters. Adjust num_iterations_eplb_update and num_wait_worker_iterations based on workload patterns. 
+ +```shell +vllm serve Qwen/Qwen3-235B-A22B \ + --tensor-parallel-size 16 \ + --enable-expert-parallel \ + --additional-config '{ + "dynamic_eplb": true, + "num_iterations_eplb_update": 400, + "gate_eplb": true, + "num_wait_worker_iterations": 30 + }' +``` + +### Static EPLB +#### Initial Setup (Record Expert Map) + +Generate the initial expert distribution map using expert_map_record_path. This creates a baseline configuration for future deployments. + +```shell +vllm serve Qwen/Qwen3-235B-A22B \ + --tensor-parallel-size 16 \ + --enable-expert-parallel \ + --additional-config '{ + "expert_map_record_path": "/path/to/eplb.json", + "init_redundancy_expert": 16, + "dynamic_eplb": true, + "num_iterations_eplb_update": 400, + "gate_eplb": true, + "num_wait_worker_iterations": 30 + }' +``` + +#### Subsequent Deployments (Use Recorded Map) +Load the pre-recorded expert map for consistent performance. This avoids recalculating distributions at runtime. + +```shell +vllm serve Qwen/Qwen3-235B-A22B \ + --tensor-parallel-size 16 \ + --enable-expert-parallel \ + --additional-config '{ + "expert_map_path": "/path/to/eplb.json" + }' +``` + +## Critical Considerations +1. Parameter Tuning: + - num_iterations_eplb_update: Higher values (e.g., 400+) for stable workloads; lower values (e.g., 100-200) for fluctuating traffic. + - num_wait_worker_iterations: Should be ≥30 to avoid premature balancing during startup. + - init_redundancy_expert: Must match tensor-parallel size (e.g., 16 for 16 GPUs) to ensure sufficient redundancy. + +2. Hardware Requirements: + - Ensure all GPUs have identical memory capacity and compute capabilities. + - Network bandwidth must support expert redistribution traffic (≥10Gbps recommended). + +3. Model Compatibility: + - Only MoE models with explicit expert parallelism support (e.g., Qwen3-235B-A22B) are compatible. + - Verify model architecture supports dynamic expert routing via --enable-expert-parallel. + +4. 
Gating Configuration: + - When gate_eplb=true, validate that the gating mechanism can handle expert movement without routing errors. + - Test with synthetic workloads before production deployment. + +5. Monitoring & Validation: + - Track metrics: expert_load_balance_ratio, ttft_p99, tpot_avg, and gpu_utilization. + - Use vllm monitor to detect imbalances during runtime. + - Always verify expert map JSON structure before loading (validate with jq or similar tools). + +6. Startup Behavior: + - Initial requests may experience higher latency during the first balancing cycle (typically 1-2 minutes). + - Avoid sudden traffic spikes during warm-up phase. + +7. Common Pitfalls: + - Incorrect tensor-parallel-size vs. actual GPU count → causes resource underutilization. + - Using expert_map_path without generating the map first → runtime errors. + - Setting init_redundancy_expert > available GPUs → system failure. diff --git a/docs/source/user_guide/feature_guide/images/eplb_img.png b/docs/source/user_guide/feature_guide/images/eplb_img.png new file mode 100644 index 0000000000000000000000000000000000000000..2888b17f4de5172d38d1a030671d90506837fdae GIT binary patch literal 56081 zcmb6BRa{lw7xxX5(j^Ge-6h?fibzPebayu!5RmRJ1*A(rx~03jyF)sBr(Xa2exH-) z;MpG!bnmtHUNgoRzcI#l8m9O`1`U}2843yt?X9e&G87anITRE$3nC172UqHg1{9P( z)LTh06?eU(OfwCf`NqQ2(aYVX_-j6|5p=iftimY#SJWkR5I7`c3NpUkh^-8#5!n$w zv$tO<$RM%_;H~$o6-^l=)yPJVg!DNUKGQ^8>v0m|iAK%^P7?J;d3*zlPR`uT;oeYHB}OT&==bSUfCe#xVls=8f6$K~ zKP2J5q6?aunUM;6*O!(asUrWIt5&`1SR-auRz^m~!0qndUN;vPGz<(0q`t{O1Pmr7 zCTC~o=9U&OPtP1vD=S7kJas+2m0%RSoaVedDtda0LmihhTG>~am{6nOdtYN?e^gR} z7JtF&si9IXnBiL`Ft} zNgJP>4AjjO@>0{$A!<_({re;j0X|zAR9{T{2L^U_cFN1kXJ%$v-Onk=$VA1&0{HKA z6%gWi98TR(hNKbESY)mWU;N?X`hA*a~qwDMMM~VgysF; z2KBP1ya=8}m{PEL}k=IBgWf_XSPI^yEu`uzDbcyta94#m0o z`696PK<;!CmC>?EWD;t|&N6XBPj*jc=YhVgv6~BH-eZ^#u zxC6T|^x>DnhY#qesO`#qb*-%*rZn~RY|YK-7#Jv8hw7`VCnhIhJ$&UO{xhcjRwt(a zynq4}Kkh%<>FY)K|LhO;G)S}e_bm>5R*=-x)N(NG!v789A;f1_gOHH$`1lwMFz9?} 
zXo!=W8yyWzRaG@VFAwv^EF?LaRODoDPeV)VYdJ3;pQf_1vZf|dr({CAvU_TJy4+Jk z7g+E1c2Q#D%$OL=uW-gCg@t`HGq}iv1O%bt|5>Hyb+VC-tt|}=&AYd6>FMa|8ybd( zhh1D;!0x37F9wFgwp^yn1PJi+u3r%^QCM4|jJ#uj_~ck-r-R;$3#c z+MAYB|aaYC*ZA;5)=0i4j@`$xxiI6 zH+zF=0It*C{tTGqhYurFM!gi2lwhXFyTDAv(<-!AZEbB4fN`#F1q6|klgD6|S{<_f z&lmjNm?f>P&$s7bw{D&_*VL$}sMLJ_j!@b+I2c~|hM(W__G~?wRsSVx_?IuX%Z;jf zdMs2_aLQ7{KPoG83JVL%%4~q6i4G19mh@0pPi|@wpbE+dvjz-1padTuUrZb9Xflr} zs{h9M3yhP2g{9!zH*74du&5|*PEI&K0zyJq>WiBjS_URoR{d7*`K2YIZaNAIaSe^v zxXjGV7Ut%aRaG9Io;1|d@bK`B|+kxDS&!a~F)I^Y~K`p3u5?=E*0_jY!`29Yki z{JsAXSlpX80s;bKlpBN)GG#DhisshV)+Q!gZ{GYh_8EAO(y}r>US2Y7H@BaqrA&;B zC11a845#z>-L8lH{_Q-*1lB`XSeT2ecJjM|ev>O&O~k@~W}(jrf?h&Gf^0$sHVX^O z%iy!1dsQ)*GloK_=FFQfr}y{wudc3&is;Qd z`}^qy1)CchJiyDwrmUodk%3`geja}b7^u0`hyOOS8raOLYHG#B#S=-sP=`2mN&eyi z>NeKaV>semHH?2Rnil~}aj?Jt$;>P}JNvt?|9{azX6`D=VuS zJF79<77hv-@;F3$>g#Va1>AOacX_zEXE3+i{N2JcGl^+wcl-Jfj~5aDYM55n zmDMV+AIy`<`wr(87xfz);s1HF`4FIEnh^i_xN18O(%!h37^vRPrhLj%xc@Hr*&J=O z&Wl?u@ozq2d zK13+JRWZbnf70qmd#ox(_y75owc^Hq9}P_WLlT(v%;>+Kth}tu$eS3t9XemA$L@av z_<2-rv|Lg2RzYrHNVey8f>MZ?jg<%ZnuzPyZLggFx^d&x>ZYjtiV~6|2LT~1YA$8` z&B4eTy6t)cJ}d0=fntIMqer#>GrSZ2q^u}ug&)`!DfS0>%atQTGrDwr06n42h3VL8FV6@_Fr|9%ec?;u|bApZA{yTBZyPe)|`?WQdZ zFy5-b99}i&m=7k?KL3v&1{2qV4%7QjOaAXN(Z6DI9vY}B5G-o)`+QAqh)b}D4+k-t z)qviphMk(8-rU$g9ES9Y+%LW+jRgkP)6)Z@-U56dt?>;pnBi7Py91bm7-BMV^7XZ~ z%I4(hed(-CBE=Js=Ro~RT_&Yg$O&Z$T*qE61uC96@G!qdC zI&Sq{?v7e8`Y@~!`8+?~b#{vXdpHc<8A(hFh7?ZX_r4kKR`Q{5-)$dn+hNEG5JL*C zp8q|?1xr$L^4;yNxVShF!rEF|QN#k*_##iDw?IDt=iXMpcU!4cBZVV;I`#ACnI2_%CGuH?l@jk?=A{lZUdvNA#=GwSqDd~po55%0NZyT|R%LJ!~g!`5*AG--1 zg@=aa%#&1%bI;a(`lG-c{@1#X)Yq6mgapx`JU18mUK|I4jDv$ys4V;TEi{t9_-*Ot z))ollLA#25BO|R`qaQ%Ty}G#}B_dkVs&8y0z{DhDVBi-LLL8xj(5jO*P_Nn$}5SCo>H0+LBvyQ!iA z-jG?+d$Ev@3xQq&S4J70c=^YO3y5*`i&4Ov)0!4Ak0 z4Db>Y6884?Lc_v2NWFB7L6kcSnFTkHgzlq{{RU^2CcSLXW zbub&V__46NCco#vp~ynHzPKpMKRbMD!=JUxR346{^>@pFm|yJi)s1CBDK^f5)!E3E$2d19ehnW1nP+BV>iS%t(=6*^)sFb73BCq_?y1kl?V{|;ccg#& 
znOY%B5PBU-2`Zr8^Qss`35OFP=LFnNh%(T?ZU!+KBwS#hBD{DZ%Jt>TmmG1BQ6vPy z`P=(T`+GaPxQv^wzEF9%UbG7fJ6OZ*K2wmB8|;s#&CJXU`TQ9b4Gn`t5D|A{dmBC) zk5vy98JU715YoPTT&DdKq`D==#US}?sIQ-znp*SrIhdk%SJu-L=H`BpC>EEKO~x{& ziitzfb$iPT{GBzWtg@jYh05Oyt;`9>peq;X7VnqnJGM*HH$%iD2So#at}24=Kd zh(X676i*RCFpU!QMH|s0Iip&j8hQalW^u_NQ4a5e&mJ!>o1L}S-3k0Sg$pf{OrCU! zlhYWFP&KuUWQW{@HR3Tt6V5-&5&@~VSTs5jZ%jf0+z7PR+U~B5do)Pd{r#abE_Qc! zDJdvk00 zL?ff1bQt9b7Y0^Wvx|s`WH=oj9*$=4eXFQYh#(1g6%ZxrgDG~KVp0G_A)dFev=lVs zcfbFgQYOmS#01MooNB;5v9OTtR$os~&(~L&4xJZ2%+ayU`;NWYd+>>-JxKIQG(E?n zGJkE&V2o*_DSl0K)$flG6_q6f9sG1)IhH9>BG7y}N|-czNTFm*)Ik{hOJid|A&)OC zAnOLJu2CPE#V~hr0VQ2jR)&UFR9p=CptP&oK@<2zH00Btr8d7mk;#-I0zhSu@<(yu zd9p6HYlO7@7>TO!oE=fokP8-I{6H!4wyAa|jgxd?aSv)nHKkZ0EB$KGLfHOy!{Vrj z536V9D|YLM7jPc3UJK!@=nt{DogGr;I{Y>+yZmKD??Zlb@>c#B4@0be@8Z$?-6&vh zC}jA`14Gb7kMP`2Rn|!iJ8P2aJ8{Uj_r!tjoAB~!{wkNBth%0Av-w$Dev_*OJbTTB zk7x2yb&papwRaA>aFf4J^&YopN#MkPkM1s>xq!ji(Ow`D7nBQ>PxT-7>mz*kM378p zWDt>(W=BS%sm9qZ)k7zV3L<1fbyZeZ!^rvH2RVUg2~kDsMWD7b23;%zY3yK`7#oA?4Hl}KZofv%UaF#NTsw6X=v0q7g1h`wSJ0WOS8Wvp z!AxVvL%v=!bqeRiJ+ovo`_f!<8nziN;%#O^1!omIdjBMh^0boBpegooIm>K68m|8B z^UN64cGDM>Ps3a@5oXblerAlvk)X0NF;(hnR4i&Sw40$ zHobUQ==kFev53=sWpFN8in&As-|jYt$n)*;&p)?G56CT5+ zCA6<_lRk^%vk{Ep?+G@eI>*0@R-$doGtEL>2};Sarm4l~Htm!%?CQVVszaOo-`2R; z>M#2EUHRijS{j=6q(Ut|a`tI%X69N8cG29z0%ZAH>+8^N(3u+>8~m=p9?Q}~R~)6>(=&Q2o_gka*F8KA8EeIeIk;bW3c zPCptO`89TcYRtW?jI8Vhr94nI0`kTHLEO|{b1tK!Lu_t(t$NCVel7q!pc77ex#*6D zrlwfKn!V^9NUSa{*oolZFz=x{hb@G(w9~qk7OkzTsbdIKdCh+|?oC5+sl>#>Hg$S~ z*Z*alY*d?+RzS@rw<`RqtOM6HWSvm@iA^uuF*f&9%RR+P_r=r0vhpQl)nm6Q8nWiLPd zD)RjPfn;lwKmB#VMa_a^d5GY)l}X=dr>yMBHMOI{jyNe!swhU_Qa;K@=M!^%LK`Qn zrVh`8t=z4qVq^7m^v}ZjjiNGs_iujd>s?F?p(refv0A=||2eaeVJ6f+;7u{4T~{>W zc$UH+hOGEz7O&oM-&wUBOLk?4*V^mM`R8{`(r@Ji7`?$CHz;-`1Wj@l#{-ae zH9v?-;LU&FO$oyku=E(i)4{DfnHz^Cv!@(g&MAu=_mABzuUxO}vOg6h9L|9nB=bS&N zsHg^7pkg&f=l-&uU=MK)RZvS)6Kk@#Js!AEvEfQ$4Ruh^P*A|AyJ-)Am5Cw>pe3`c z?3R6 z+SIpKJkpKi{j1>!EYBS(vl--P4{hx++dNjrH)t%FFFwVxO*VYLw=;Wo!CsI5Tg8aL 
zQztCHU15j7`13`AeJ~T_g zX+}L>e9UVTG{Irc<3B}5*<@23nkP|^`*z4st!~hj{K8wORI20Sk4-KaEnTs}$7L$9 zN}8sn29|6KKXb1yCu1(YruNcbR!tP&>HD%zrod*sE6MjPvX-Y~+NC>2Q+6QsJ!W;SB$$x|X z1Vj=a1C|Fy2^JodLQ89EI30YP3=GzupME~kG5?lVR-Ra0&6Z4^g=a0gBX^RL`ZY5U zBoh@CCB6VEVFxS}xE*q1VHU_!XwYcz_`!M|wx2$I8V7oO6>{(|5Ciol70q>cbg zTwIW0(I~ZkM%u1~%_SoFq)RvdF&V{ewaVNV+mpe4FZD;$a@hdU055KSQ0RoS&-JCV zu~JS1s>ebYS3-KgQl4ZXgn#Jf;q3Aa5$_qS2Hu_4y8GCDsK{Pg&FhDL>=1@(RK4Y2 z&J6C&UAlS~aV@EDEUr#E&1kx~|=n!4R#rQXdsV zIKd;3AU{&TEih8`kT$kr;{3$xC7F`#(cEQIytug)76^QH?qUeWQ}^gF9l0Y3Dif_< zd=2Kcb~0iwvNWj27_UO_TH8}9Lh~MC?b=}ygX#$=kS-LT!XykXFK6D0>w_+iJ5nEl z`8#!Wby-=&dcQo7BpPvwe2oH7{GxgRRIX@*vb!K{Cn+T7Ff5W{VC3dN8<(fV&D(P{0mfEwI>s;@R|7W!cFL{h>+Lb#^iERjo*Mu)Hr9y``5zK_E{r ztQP^1TF2IXb0(9W9>3o^iG7A6(l>)qGIdMpBe(^7F{Y4Q*y|k%lqX>ZM3b88WwzTYjxm7P$~(U8yS7W3f*<7} zu<|^a-O4t;nZ4(%9=DEop5SpT7tZ^y_5}|O9Z=>pdV>@n9}i^5lzDJ#2@O1eN zo>o$>-Pf*F-BTz>BQqN7ue1hVS)sms)HX&fADi7WE?`wm_}V1A!OGk-;&K-eiiv{o zP;~JgHU`!A#HluzrX!1e)%y~Udj~u2E1o2rW_e|{u*2h;&G;o!?z_vD73qapjn~oP zNI$Cea~NgEJ@l;5EGO(o4bSsORgw`>5X>jTJwt$CoW65&#I8wW7|LG$xzpl#0(008 z;=q2Pec%fV27UNQ*wH{5Yx!Uhwh2#D7gLe{Y_s~=&~cP7vY80Gpc=WplNP^Z^^uJi<7m;7)|#z0;u5{GmbcB|_c#+AL=w}J2G$aT`)E;|J`t}ryFwc^ zJgs(v5_uV4!&r@$+lonweXC;H+1eFL_) zo*)8@fxc?evr@qt{pHekEB;NHuX}%ev{15qQphxbo%}Za!;ni)*rl;!)cf-Z-iv!m zGP9OQ^b6%=?@YVsx~2}A?@Ovp*?ux_@7XmacDqE6nu)0UV2#>_!)1E!kpuqRa?27P zVx}KjRT$bNzGO@<8BAxyluI9!BODATl#|0e^TvKw)zO^S(OlM{wk(egY-XVfP>HdR z&v-4?>R3}-3)(8?a&4}-czB?;vntF=GC#0uUmBJ(eBN$=fsX>4+e#yI_PLZdvPO1# zpM-T{U^!X-g8`NO{VEX(6rokK*}5@3%*vRm83Ifd|AxN0Paml`BvNd z*W9zGx+z~b&B-@5z0*0d<{r8`<_`*fB<@Wz$Qgzf3J&I6#%5Nbw-rpNa`=d0FI!n& zWWWeU?CwQBKKMOHS05$L`(}SvRf~N_)^tcS{+WXOnP%X}?+$kg@2z^4Oz0badAHES z{raz*2pVfWl8R=T#C=~|UsU}3c3Z+|g&lDCt-|&M>WS3c^}yR$E)8uy1KGB?G89GW z3zCW~sw%!MHO4|wpDx?$I;xT#?PkG?aJilo&N5-(aY_T zr0uAk{JQUQqv=Un*)f}%o62%s%9oq@-=lHUNm#gNSU9IxrX<>(LL2i7?(@H+$uTRU zF?2ChQk0^E^X0|u?D(q^%=qfce^6oAY{*3!e zc^M=>AX`jbX(Wa<6x6gne*X{|9X-i%yYN;{4s?eOkB;(m{hPx;y$_>NO}1O*%)3n? 
z^m{~6xoAqXj#J{c=ntn&A*V4K%p557q6T5oYR(9ylU%6#YSGkw$=KH$5Q6y^Irg7# zA?qdD8nIS!-z{U!X5y3MX1}WV3p#zed|D~EhjA_FbbH#u?oyrjHtmRA`{kp{laB4| z)+#yU@_RauR|Apcdeo`kmOdB?ADM)sObta%HtU-=^GiBlnLD#JGn__|u}T;P|0==X zNbuNev!P*o_~QO-DM81Ek2+)ii4Nz0J`<-P4!7t|#=HwvCk1vMV;p1CYT7$vTD`Pj)aqFO_`4d~|0p>MF~uggBD1EIK)R zhDnG;zm!>{urH1EL;5gnFRD$R@1 zu-H3>P6G1w0CU?n(k-!br@!DlAc7_?7d$)`$*nSL88@S7-U;gJsxng2cODOX3;Y@> z8oUDh4-a?2!JmU=q=ULYe=d*;dLNT*@x)8w6@1d=^p`v-vAk)if?4tW zsxfTZh*6168fi_+dX;)g^-4%eqTT#l40^w`q@GLaliS`!35F(Kl*Kl^kE~ubEMf|I zITe?ouJlfln@jB-ozgo%lMp#3w0x>q+TdH);CpZJG&R|~DIYD9HHF44d{E*y zrmFt%s3h-n+|rwv&t(FiAFSyyPhUiT?r>f;Ugf~%Kn~rpLt}dHc@G^BK;KXQ_U+r) z*jSJXW@Tl8;v`^6x`m3c;KW=vd+!D8T$Pp55C5n#LOtx+$z@CEOEmibH_M--8y;OGJEqNX7?oFbX{j} zU-{@*ZRb$$1Yuq4U{hynTV-uOW!E-)(=_|oGJW4L{=_(Q2eCw(YCfHIF7?;EsMa@G zr@%KKTh;{lp+9T2-V^9GTApZa+16V!G+EN|l`!%bG4K`-2o(0s(+@5%^v}@^%+bke z55{NDG4%7%4KC2Ha3^DN#b8e0tbfGWO4u^RTQ_KN%{pg5fknXXGNDmBK}m+^E&p<% z6OU70PR3NsLT*%=bC6SXXk?6KbC|LHDf^^&EpLE+WY{ZE{li;fe*VUNi;r&VH(dxA zfJgx~?&YN=P^P?IuZM?_oen?n&V-EgF1SAXUj1s^XXYVYePAQjNwn&AKQZTE+qUuw zuAm>Q66OQ1=;IOB*H z(_fAg;7wufYZGqxj$&-R$mn68*JIs^s*Z2$CbodiC(JNh+O5zv-Z0J;@w?_c#J>+m z9nvYqgJ(H^r;k0I81q9FrdRo9YWkXE41=2{Z-8@TfoVNaWj>D&LdCb$7@J9-eDvN` zlcf%SgnNHvOlbO~WItH&a#>sn!$!C}8 zN0%GgC-eITYrA{%=LS;^jnHujVPE}_twDedwA~9K2=uaWKP+yv48R)Dpj#TlZ1GIT zTZ-LgV4QqqdXKWZ$C_o-+=(iV#NE%_@20CN*(pmPsNAJJ%DXVZzl=4(kLS^t{NOnY zcPE05V)-S)JUn_|KHw>pJ29I}KjVi^8k;WWbn2Ea;g;Hg*_O@Myt1R6^_kgKgb*Va zm)ajcu-^|ZETnmoN`dzC&r9K`tKOcfm3<9;^Y?JW*Y3458WpHYw%Cm=owK01k)0$H zO`_Z8oB50?R4=}^8Jd?{T6*{L0>1v2a%Dhi!^Y73*Hy#v*T_G1<*M=C<-ml_mdLpz z+$)mXu#1>b(S9PMtmEM-W<*+|sNti@qi5K%aF+e(d=@WTQ7rm~1YgA&Iv+2@1t(@jZR@)m8mQbwP2BU!&t)$!`w`Fjk#yp* zn+wROau6MK74Ke5wUxVGeohOGDMl0wc^M<9(Op62EU$N#mH|~!3Kc!XMM5tui&%7P z41fEoLPGxRdwS!B;fuHpw>;|Mtu33x#7r(uK^M^SURGASTU$pR0_UKoKP%dMUzvh>AJ?<8 zA^>O8PihsR*+n4`F|S1GaJGwP}un%wObIe zLlwP6g>FWRZVHM2)ixA*st^rwrsKDf`)|WeKFfYNc>m?dUp7yADrvVuM%>t|(~foV 
zr}`}yKQ_6Ivl}~~UY@Aan7NLhCrq~}jq?j8>2{zM9y6j~D~HKwh#O-oVUV_cO3G{WM?x6mJXHEL@rP!Z|Qv%7u$xpbu6H z8ynlu&~SNqxuzEbKq4rAdq*C^xIHTv%_1|jd<78%9QHXjMe6qMjTF9NB7*?nEzQ6U5+g=IAu62xxTtCFDduWz$-@ zTH}eHmhe-v%vO1V)&0Tu@81;^6c*>_Yk&TnB>78JA^$^Ez_ZE&#sUCoqwE=x6$<_O zzr;#e$>LaknMzQp8??TA%1=efKCsnz zO-fXR?cpix_A>q@=vV_z&H{r0;3p>nNuFZWYEHFMAM^Qok_i<`hDdSoJ@^OIRn#U) zHh)5*u0C*cb3!mzOW)Im-&&Ai1v7MyI6~5+~;>SoueR-QVA5F7e(R19T367GnTPi(#gEqoJ%&kfI{* zh1)yV70abhG?_d=OTM;!9NhLYlB|z#Ubk0EUteES^Vcmu%x}mwR9m>NOX>Rh`t!5z z!o+HBOUu(jtrcik*Pzq^Tu$tfY}vI?1d51+gamZ=K$N^Deqj$%EWH^CE_a+OSP_dQxhfPKb&mg?(`MK zLq=H?{7iLyJ>tE%_Wo5YNg6qYqQ?OP3JQv%ldmty1t4KS=^q6NNlHe>%*U5E?$@!lS2D}DX^`NdRF&={TU%c+1HzCvYD zQGY;hOid*e741Jg*tKFYQbB)$(5L{QR7FJ;x?^|ksHC)%h=AZ7$O;xiUq^7}i5Gq%e+ad{@tPu!&xFis8*4w~B9z#AkbB(zf~QpaRLATX7r0s>7y1_Sm00L9=rLBPRe zeZ$Ph$7f`;8IH#aIq`oDkVV;d?;wN$+x}_(PrL#GliYXw0HNWrTY>X~%=mbF12P)$ z_+yikmm9xO0IN(#D@`Wq1~|4HLId}1-pS5RNMz(NU?xQy+9v>4o1T`ILv(LvX!w$Z z1ii%uAfk45vlA1L^6Q%$iB-{AFjnBxPgh#xbQFR_unS)>*VWXtlcLukDuZTZN_I9z zI@COXFH=X@^jiB<*>q#LqHk<(<0#CkICI!^-m7^#2R-V;Rr3Tm;d=Xb_w>+^_sLL$ zrD12kh{Q81{69P-VlT`wloLA_7ohII$N}sYBn3k>0n8}uH^@Di4d^r(NC6i0OcJp2ying)sa|&j!CV7^b7r;^HO&ib8qSw@(ii_4V^H`LlvVOf+IfB7m~T zt71l9tZRVkPc(_p#LP@xRaI4GWoT$9R9iqYAaK_34YdAEOwV`$r#>y(X5g}iYE8-3g~|{6XURV(0RH6@T$ql&eu&?VY8pE7Su-RrTu^cIkN8@Nn@S&W$B$uhmK0*O({D1n!8Auu%~)*)z+_hd6mD?@0zKvF zCK@3gig_-QW(A+#s9`MfcsrX1q-{W024+iAqZ;@8 z&g;G|h?95jHS(?Ksp6DsA%)Xk)Se_!Hy-r|j^t&NOBp8dKV1AY$i6gYZ)|Vbq4al= zx1Odi7shdO;>Fnv?@#5`h$6_42zzN_lk)Ql_OEuWI2lLIX6McO36WC`0(s-nzUeerI<9EDn zYBYS5iz-BY<{iD?ah9up$FIZWk3Eim{NE0}S~TTdwUL`^HKzqK@8xQ?Z^U!GpYH8S zgQ09nc4*X`LWu9g?{PVTfoj>Bhv$$}JIiqDNaIuc-DXE7G^GjzG2H8<|GN{qcO97M zD$0w(z|pek@UB}eN)~QzW+5L-%P-ewZRq^CVHblDfX)O*25L~~Vu}i(eqj1 z5i#1Z|M}WZO)XQKk+1wpJ+kDdM;{)Vn_yl-tzu_;M;@MW!Ve8(d`H}OKTO>kSv;=p z3@dT@cRuP(2ME~x#twR$uss^$%Natunl1;ZjpCUQc3^4-Z&)z_vtshVF)B~Q#p zV^!BIe=WI6dlq^1xur88`3X0Q7^xfigUE>r7&z@Z$ z$lY+4HdZYPiUkTSMwdlgev3*@QelCql?n+NQ~lU0^xv#1^T&ng-)AuWj5JY}c3Npr 
z@8{bMH`2?KU5a(3bL@2}Z9)C(5ss3KT6~B~ZlsdS6oRaj)Wvz|>xefGBes0wQlhSa zUbDdXbC_cUZSWPyu0H@2y?VM??OFj`amdSJ3vlv7=dBVwCp%LEAEnj#@^@EDW7dQ% zmSnT{cw1#l-}s^TswTB$WzrW8hZAdB(30)+>aha;>m z;!ut@c_{gfjwyRs9nAcICX0MFOX$J71>a@NI8lZ-)D$0a=d9GSyo&x8r18^)NTnB(n)WRkvKUqd^G(`P0YdP=_C zIs6P1A>aUlrlqC@Yvg2a#7rEM^yd510}Ta1p$8#~=>8<+lm%I*df54z?gqt&AX7N? zkCk><`LC67YGEW;`~=l6;@XCw0yNH+4Z*#86%0cX5;%hyWrAsn>;fuq8s*h`@puu( zKF0;uXM!#it3UY1co#a!M}-_>9)F0lu_}9?Akdqt>9%N}$IY8hQc)TF(o~n(v}C~y@EoQrc+#JcXu~{vLBf{>LL_h1KA1$lG8+c#qXP&8>KxU zYWDX7{-m!P5hH^$Yag-%?`GScul^`OO~R}~Xn~!8ObxJvZV@LzK_SL<+h#1y4^vrh z@l(z#P>i+=a5-YC2c7xc!p7PAh5e58jAba=AY)-pwDY$tL;^lK!@7(v6s(`#Y{J22 znn{T|tFt4<#1w}nML)~iIM|2;HXYlg_oJg@T{E$iQYUH0sZXfjMg$3xN#bijXu0=4 zZ*^ylf8<9g>#TtjS*-taNXQ)>B5o^5J<@eC}B26`! zrN!1Mw=&OnRGovz+Hjc<#6qKP=SNa6#I^~82@`ja>hR!L?}Si?^-B4Sf6adJ{oJYVxL zNED6$ihFQyaQ(~B)bzz}#LSG^f=F%)vU6>5@xhI6Q|4L=F^VuI6Z6n-XONuKylI^o z8&|uucMHI9TIgznwI@1UhPHgA#bZOHMrO8Na-4{x%=4pVgf_e$Jx*P-Ldo7%<7tV6 zr;M0C3DDyJGX+TrP~jjIS;t-iea@gxVr0P2f#bYD*(=I`$`U|bLLl7mfW`#yDM*IE zmQ+OskSQRp-o1MVL>w{>;IOv0wm^x&#Mn45KR=%iq=hdPawK=ujp_T8V!n$SE9OD* zCDIovk}XgWsD6hMBESFkRxBLGId1RGZ=!n)ZRm6fYXWOZYbI-6V_y5U_SNTWp3W-7 z;d{d1Y`SFHBgzX7e$2=p;hMsn!^BK2{R@^CMY?!bUi@ zNt_9ZwW3+9@C>lxm&+Kd<7}CZEWhDiQU1I=Kqm9BuvTEO4i8PzP~2Qj!J}8+aa_}0 zQ(X&94z2*grK6yz-tzgf$?tvz1Nx!CN9O)^K0aoEq92`|P)0K;o&+L=M5l)^9VlGl zwMWG3J@&1WluAKG{gFt3@`!XyQDZRE&?3&%(ZCA{HdVW~q)))LJY^p_e2W-=Z|k~6 z3|o=UgR_@+L3!Za%2Yq<%JHGqVYoi;nZC^1w@4`)$?R#b!13QdE$#ZAU`ZcOC7V(@ zy##O{wezXEc*fVc$BrV-ReQ!bl6M#`MHQAmfBs$*;)jr9)Izzj@YaB1Hm}x~PrA2< z(mNvy={@7Y{QY%WHcgzTY?^JA%!e!oMswKJ1-(zF-RQ)FD=W_v`EqA%JHXYNi-o4N?YCMs1~34h=b{CeQ4c!V8Rsr4rp6srn}h?o|v zIV7<8zAI$Y^pqnu-S>(-J3QVxoYm1PWPuaFOKgqUm;@<{y3N?V1G)^C<>-c620O8h zJ>C6|QJK&>RN?F*Z&`0g@YOs|({h#)FG?FIdKlj`*ru(IZlV(lY(0>1Zg>b1o=R0R zUhwQE4S=FOMz)|~_U85HL&I)VhXGz}PTB%LZAH!3#&qQ(^1Z8MjRH~w@6m&3Y$z}m zUU4MA`E97lXGJ`Rxl6Skm8-r#+A~|DFRcB_YDtVi*7{u;S^}Ox(=fFqgbpjTH 
zemtIW!+>krc0_TB>4BskTM>cII}Hh(8u}XQ7mxdW{ghd+o}Z&TMX6a)oFJZAdfa7SJy@lnoe7rTn^R1!Sct{$te%JUdSi|7rgI;IGqIrCJ<0XIaRK^ zp-_RnSEsknGlcVHR>5|E^vHRWG8MY<)-}R}sC@o7dvc=>&b`~SIqT}rS*S{x89wF` zk*5l*-<$V)h(675O6&Mjr#>f8<<#4GrvaSJng)QW(9CB<{@f`5pj=mX85HF-YO1P~ zK7NFa+9upep4G0V?A2nqq`QG_vF~$`Bk6xjJSvyXAWJmxj$~x!_r$~CpJ(`lPKxmv z;l*Pr%1qxv2Wz_ntVT*#&GNTh^FWf$-LiNtU1E3>L1)3e4W<7`ri{rKp5MLAA)6v7 zJ(9uqG2w>_gIS5Q(wXw(AnKtrZ4cv++afm3k*cN#YYOe$mxE9ay-%@K%7UH;=I&NH zDrRxINpb_`wi^ZoANcK&>f3L1E}Z-vKb2B{vLl3%OnaF)M2>su_&Fv#IRyjMFDgx}Bt*o|XT`pA0VlC9I>|jGLuDDkQtU8V!0F z0PkOb_lnEO$?^Mh<45~BBz%4vWbE#8UM6EiH-zmL&k@r5L8r1>J%KH%Y&5G`qcwXIIN& z)*{QNm+;%oYOV>8&nPHv*}CLq!H{-6%jQudcURqHo2g^X>as^r0skO%A&HgDP~_~T z@_W>k`uV`F0m^f${tg{c(;_;N(#9QRt`J_ zrpbnWJE`D1{ASZRs;<8nH!30Ac2SPG9j?e)#(Gptb8BWyp-B*gF-`~ic=c}bUyw z-&@M9Qyjz=lM$B>xyIoJx~xKy`+mIUx@laiIRV%~MOBqN;6_r=T>nb^Hd0<)&B4Ov zc<;oB)0SEXJ~Cc7Ijul2_RCQjYN(CMIbAkTYfOp%!Get|!hi2^)URq|8DMlxd!`+fPEZDS6D1*%?FJtT zduc2ZU*}UINHh(XPXUP_Ak|}(&S8+lnWm_&L9P-iZw&YT!LMuKi;;bA-HGXL@^V?P zJzspcKye?IExN5=HG710qSd=^Pp{`QD=-QzeWylyw|pgR!i{U9q;vbx1cQWM52DnK zE7j(zk;N1TE#2GFo6fMg(@NzYE*mI0EJEDMoX{FewK%53Kw0fb(K!hs3^4&=(c%1O zKM;i1JU^-k>~I-wXsn>k+Yw9<|6`0~HKx-HW2^(R-XY9!xJPa~&qz2@_|YwvHOCp+ zserp)j}5Sj0t%WNV#IGo4Qh2wO~H~tfD2H2XSDFo?fl+xlX^)>NefC!C@loJmygtY zht)yc!Cwy{F6W;kr$dnP4Q;1x&uHEAZftSh zi_SN97seUe4TXTYQiuGYzft4J)aej&mvHcFws;ofFXAub0`X=}Ek#2R}B#@SJLNQPAg0MoD>6s1fwB4seAgyWcj= z0r>$ErU-jI)gD-Ui?Tx37_Udj%ZFW0?uX^)GTv(bs9SLEk5}4fsaYfE zWuCYI^ylZNRt@^+Gf=(%(?@}r4>L2s;cl8gz2#U59U)B@l$b&S*g*^GekXe0_V;#|CEbNMP)h>olkf&7Lp%fwJ- z`Plw?T?lWO>{LG6^L$z(xCLb=MTUo^c``WuTSLI_6SK?0$lHUJB-vEy|BS~!WiPuP z67VNe@0Zrr{0dKANb-A#j*Je;sK&af94&QTrtu$jr<+lSf}N}rvx+l@IkllA$iqjn zXzh(io`iR)VFSoG4XLrPR~W&@YQi{#D-#S2M^)|c^OXFcfQ7mRIOXB1cttheXUqMq z@6MI6{lEq5=={kj&nQU!`RXh`yrD32B?e>%g^6D&r|Zt*6Igr-J~#ypUh)t-j=J$t z#!4XzfAqu2%dffkF;Z1GsJgjg)ymALAFl-P_ zlKIdOsz~?g4pHZPBv0ZWO_B}Q;aq1M=iP@tm+&`^QUdc72p;UknWEp@g5c+EY$m52 zuwmJD;VG{DjA6&`WX4{+9p&JVnp;=u-DYbEX~%s(S+f3qK&cdMNkP`Hw>7QeZZnwEe0K%O 
zvwS7)d|qrOjd?8JmUyggYoT>NaaL;&ogs$Sb|mn=TLRA4vSgh=$W_QWn&H7*dw=p; zB+gNj%|G+HA0X=4UB;aBPNzPs^5J}Q^Jmt(V`5BX8H5M z8ZXz%v)cmm#WpZ2K!MwYUD%28B>h79y!W06j}>R8R=1m`7(#?1-{AHiq9C<8q*zmB zp3x;#ZymrGp&g)I@LkXz5q>C*tl@aj9=fhp(`{H+lW(=r9qfQ72B z1Mj>*Y#QuU5(}C-h;k9P@0zK%4(Q;lw{?(3A+rRU2Z=l9=kJVPcZ`4;xeLe{1AhGc z1gGp&57el_5Z%dK(0&uJrJ9n_g6HCP!NwHUTq93VkHIy zyVdusHE^$U)vaCuF-?(k5;CWh0kbzU<)H%E2zJ``V;0b*1a;Eo|892}b&EE0k&&Po zrM%}gbG3KjVdd~X@K%aQav0*x3|-v~^~k0~e*p*%glc0j2f$uBArzeq`I%Cvcu(HXvY=l%60Vqt8Skd&UK8X>wB;}>sV zBfe(b?OO^Io>eP1)xS%B9qAKX=Q_ZiJpMNPooL}AtMk6U>j=lx%@WUS-*&y2FEU5!pe?)#K;NAVY=n7R+SzWt&=ge}rZawxsUib3b zytG~CfD&??<~{~e{eA0%!0!#J!h4XTeQ~*phzuN%;ig(S8f>jck6OWl zSk95x6v*nyo~zq`m$JW!^g@vhXE&mp&VHVv>5+=v0IOYagi~g7n7%IN-oKx5fE_ziL*lL+Wg2Srv57=q1Y7d+3Za*>%`_VGkMqdG5V_ zGW@9-wd^X?ulBxuGC6vWIPEY8e02AUTsX8$%McdDxUsQ|vQe{}X+-=C_l|YR4J`3L zQjR@T&Y-a#8aLLrM9SMJUhGIp1-41q#%Fx5{!WB~&X%Y5p{Z$clj84QVaCs1v^3>e z6$GuT>4-s0Pe1LXSF8ydMK#rYtA=TRk&+V^?7ZRjX(n=fst_mtS(0d-rA@-pE1}&# zzZ{RPQG-$KIya~Sckp1-fP* zSAEq+tuf05mL@X&nwBDEF4nA)`Yv&OYnr3}=cfCb+?$xB_xAb8E{1koOL=q|8JapF zCHmpt*LG?@$3gqRMxKP`*_;nh)4q3Vn68qDZikG!y?~wwKE}x0e2p5%)^E{#Q1a-XyD5Kr-s+kKHd!p8`Nw@Lm{sPD&GvfF&8hfSGR1r_dzd9o ztiR7lu>`2~(m9FNIhlQaJooq=>hn@i4a zb=)k8-2Kj;#B^HU&tbtm+OzF}XDpBGo=ekYch9pRkIfDy&Tl2IZ#AGr=ChXb;loQ) z6Zhqazc(d6G|*d~c&(BNQyd|UI5U*dD_Kz-=q}W>9549P_8h^ATmA)aPp4>D=7Pmt z{4f>FN&ADP$&#qblk0h;NaP<1@F=K(4JV+aK1AJ?Y1gds8yS3Ne;XdRnBjb4rLAfo z^44Rni}GbMQj(?}C{ts#QO<{-6%%>GaBr-MM&ZLd5cOG78eK$twXrFUZP+`}!s|~$ zv{J_l|4Pp(UA6y7f5=SiMaawsrxjqkE+p4t`_NhXV`6JOWN_00&LeFVdF^_Q-bXx1y zNDKjBBCAb-`N|v-2A7}vsS|!m+3DK!^)V7U&PJmP05qb=%^DXP^=jIPgL~FH%BItixq6bU_5oKseCiO z+HucvUOrqjtkMMT0K;Aw3}TYu;ao$s5R&+hD^6axxhfydY{18FP5IZ7JflHhLpcl| z4^H6mG+tMhEhMi1&23WZ9>fnX$8{~G6yCpD^u3}(eaaNHLvVk?+6;I#iQ>lKLDSlBi zcr(!aSm;*|3_UkC_QzoQrP0mGc`WUK%psDcv@#JE{CNPp93OJ)kIZ1&pyrAy{>4Rg zTxb9CCCx$21v-vtN!)FVVG&=bdcDTw^qh-v5iu`?!$&gDb|MVbl=_?)3BQ5u{^9o_ zV_&vEZhEsM9gdhHabM=_H0=JTM?^Qnd-RWGIkrg+2Ma4MHMs@WwX;FD>7F_!Jz0kb 
zDk|BI@buxp1peMOU$g&VQ-krp%DGKm}hictY2LbqKL+QM*T5) zGK&3)1ECa>3hISoLs3*VxaU|E&4zq^MjqR>#Fe7I`vAhIx3aG|?R{nQu!#HEGm)mc zt?my$4TUoT6D(m7CG~i5UVjVAN876>KO-rmyW&3SgE{MI)r?5+%XE-?#R5iyvJu?n zWDx}~N^Y?Aag1ffEn>@#e<)0Y4Y~NLo*;~KODYr{u0Gje1`&VF%^{E9o>JNIJxtF) zDB^2JMXpC+V(C8oGy+!49OyT*XxJs@MPvi#i9Ci12gCCPB!(DNNDU6{p8l^#zfWH7HKDjKHp&rbcJb^kgUaSKx;^3Z5gw{@bi+^6|Gf5r zUjfcdE-qOQ6Bg=%%b={k>#`iK7&UcSw$Q$w{!&5j{l$oxSK<%LigqSK%prN@)it+3l{Axy0gw4@QE};*buq;UhjdOET^?~} zRwVF^)6!V~yG6HKGX=u7pF3a}mP*ooJA_QQY2|RzC7pY5pLM);gUzS z4868efj-bZkgsb=b*zp}J(Vzdo@Q8zhy_pUPBhuIUL!_RgVtx2%%m13_#@0@eLB%D zVGvJ$S1|Y^lXW5^7tVa*gAN(x^0pYCNEuT|na%)5*70{-6FYK%noRtR@O6K?9Hw(5 z=F|6NmN7d9<~c&=)kBe#@81WXf6t?CrCRHARPdvb!GAg!V1qw2zAiPW@xe5z{B5dV zdGX}k!ZmaBgaM2O;U*8i-|L2zT9@`&7}a>&+CGgzkU?@Q(m7gc+TFRS|7OTmMI&)g zNRfNQ2vp0I8ku@Q%tJumgeC5aImv(5Br{8ZJDB^T@chu}3}|NB2`lUV3mGUblE%RJ zT*evun{fL9<$!nLgUq;<{Fhep_L|tYYEG%F;NL$!u;(sqTXhk6~m#O_IB`g zz?_C0|AHZbj$CXNry_7PoI=y`Lx+(hXXEMZapkM7Ax`2zj9iyQYiz=P>R|_!^x>1s zyYwyt7f(k^UcZ{L?AU~i@VV^ic$arzrF*&(JND5nY6&?z3*D`SvR41XVeo9!ehj=!#F_A#4_`fFxz;v z#@5oMA-yW#*@wVfx>^Z$%;^{G1@(`2R_blsGR_&+ z4d$f}h&MkRT2eVV(J(Q%v|N|co732r0RwKiRX!yQmYD@vd;&W#CMbWGw`=HFv)RzTv~ zwGJ}Ut$EfHlX5J;slMq$R7yv1dj7<3#EfcMcdxvjQU606o)nj_cJKzd*ML1KSj4JV zW5o->!KNS81Vuwbr(CZzfOX!BjYV6)KEpq6W&PkJT$AQo>C0ynO*34b}=~mfNck2(brRq&nD0}LyLM)AV>BB)(7dm4#U#MACuiAyl47UNd<`|(tDbbEz4GI^ZI$~d6Q~-4C{gS& zy+!Y;#x&DRy+7Gmgi!JD3(-&VP|yoiJvRS%B-wWaaG&HUjIyHnW9<;rqkjg<;Yy^v zx`yR7j4sE*u?g?n?xZt(TvjmD5JL6g2C-0qLcBc;4S+1(5B~k*kW{rfhwF*K;JFU6 z^8+QcsIs0`A%v&@z8159i9dA^ir!#x*nI2BX;s%BA@F4J#Q$Yw`0~(g`9@7uFC%Wx zNgyn0XzTe_0pztFI#W$HAT+B;G6nfYOrTj9b+>~@h@Z}VxFq8pv zb`w}52j$M&wb6ImJjPoH1ep36^-X$}cTT zx3PYh+VX+#V9rp($`t4OVHDh}K1@-Z+E8gx!5yXy+!WnWYafjAe5!{F?qT(grQx^X zsynS__mcH(6LjNAuao)Lh*I8L78Vv+D=eeBYW1+X;nUKfZgEal5#qhqkg)vV$rpPa z9r!-sVvn%*(`aTxguai1F+ncw7}>Dex@{976_5|#l^w~V54l|;*imGsxeU#1Zo zPjr*pIz=_W_{NOk?m~9y5uB4AVD4PIJbV%g?gzDqWn^F~4)l6A%{0By?krSwG^OSb z|5spc3f5uqx<0cn3v8>;54oeU`SaRsXoy`$Wpp`>Yc%-UebA)z#>2fmrLy-8Jtp&E 
zMT7@x^Z$Mt*zow|1Pm1%@TTjx8DDJ;U6F7IyE<`3nyk_MZw8~rhJBS(LLiN>I)3_X z9#zNGkPoc+y!4FpdeqZyGk0y#(dYzh`xW8&XDF`SF;M$**g=dpJ>Rj|@N&#gpG^0RQ?Uii*;_R%)WK0RTg6mk|3Vyn{x$?fZ?^uPY zODzjKjBW65NA`vI7KaF!HvRT%SOYC4tgth!EIXw}{3P% zU0?5cTlwAW7gYp^l%s$Qg~1PO9dK}PiefCIe~a10jrvBxUL}Ar0wvuT%iPoIRA4#T zg-Dyr^XIS-!84Kgh=N@P$Aq^eb6vvP`g%isJ)jm$6MW&Vs;PN?Ms#+4H2iGiXJhO9 zkhciJj$-F{YAb3)z%lyAwt%g^og=*+2z18t#|jC2r>@bv7hZJ=<&@{DBB+=#Ti06! zg|e%%mGH@D6x7JmOl(-x;u)(REUOlyVAR4X`-U05AwiSjZr+ybaLSXawYGnPvnp>c zpNsJQ!!Xansny!bEUM&-W-0VWsBju9AnGrY7MjyM((Gg|R#Trt#7DT{;X@BoM!{)J zKDzJhzg{k**u^!9624^`uH}S2b4oDAqRf_q@e=M5vL|z`XCCaAKfF2Iq$VI4(NF$a z7qq0YDCKUhAO@p{>%L!zeKV6mt>0WoT8RZG{a}C7mGt{i2A*?7O-;>y@{#b;h!H`l zU?`6Gv`uIKj{RO!f6LA{J#UkEO5gQm-Ec# zQAwP^`Bl|Iex^@&Po5V9?f%;6%}hf)`|_?ph&X1wW3^WJ^RmCHqt}?fJPn3rbQx=k zVOp@@&^hl7{Och}i;7^;IZe|E@6-H}CQ=B<_=zn#&cp#L_zQN4E=i?fnIvK$({O*I zqG4WL|Lf<$4?T88$>1YOVk0|;{p&qREeD&KelUKOe;Y{@JSg1|4OiMO0?sFDB3a0t zguq0t&GRLFCq+9wCJ~ZgE|R%E_vdK~*GPI;HjHqD!Sw9&3lt; zX@7GvO}_0iq*qYkbsc7%-ndt%L653jj~Qg5F;0kL4|Td zq%tQ%h3+52WmbbX#(E)lcANgG-J(o2uXmhx;P*ZCRf|8!};)kA?*u4XN(%h0}C z>oo(@JZ4n3Pj_E0BZ9gc&m&mp?^eH?KRw&Y@5QsTOX5p2TEVxHA6Po{3m?`N@gF*K zxlweH9Q{tObkX^rVgv{vj>8-09b=#z=t9BSy0#-kO zQW{pP+3NI-em!>5(<8j6|1!M9H1hG)C1wRc7ho~QIsMLvSX3Qdn5^pv!GMe^kQ#3t z;pty0qz?frT1VN^l}>n|73(eG{{cbGBXxh0(<9ir%REPLeOd~F$MhEsT&R{Mi(?Jr z`OV>nktz~i;1(}B6dWbxkjK#jvV#(b@F#0?v2r3lGyw2Kf{*>6{DJJha%O*K5|Nz1 z_ngRLwsed+p1o5FAVC>K>b$z?v5;h#)O%e{-Kbzkp_01M8=< zmzOuG7H85ieKMXwBy7Z?k6CK4$6uNI>|ANk`H7DL>^4W93AvyZq!#x+(l_IOF%dvZ z{KM>V18kyD)k2h&(yNf=7Ts0c_y2j}_`EuoTH>o_bF92~D{zfSqeR4IlC7sne|j9u zuCFnzEslU4F%7Tz$=A#*PWRBAP~T3#r!cF;asGT7y;JlQnYNhe;Zb(9xkEq_I`tMG z(-uG7ZZlqRw$*3(^Y>}O58qtN>!{~%X7skWRmOJb1e=a@Sg-Up$Pvx zPg?X|JQPKzOkNcRYFD_JW37W>mpq>N_%_r?N41*n#UY5+c>bw_h>`@F5hl4?&ySHA z8RA31%a2+KajINi`zVtwG~keZi%!PFz}I%Eskjx zf75r4;G4_-kz|93oeTMcYxB=2NZ(Ljl4KEqQb2}S7II^#187OdjgYT@LQ!hgT6S<1 zh{9BHk0;j7dHJFWzIT8*s0#r0{RXd!L3z`YbXyFKZQO%v3RXRE-^uXfIl`aVti4{# 
zojY1S-mY@t*C=H?>%{gX*Bp`78&g@L7^eP^u%0&}Ih40mI9|$%`M7wZ9mRjWA!ydt zZ>=C|sn{`eBi<3wwVNyI-jnrNVR+W-;IrrVJNb5JKen_xUqM;~S+J8CY{UsWViyou zY9fwmGBn!KGhU0ihDn3nrNN|~A?Uq7zT?70g2z)+a$?BY%G6T9)Y7m6*1CU|0*cqL ztuyB7#a;2JY&uc&P#k$Mj5i-?Eu!;l?LKPJ?3B9clv&6B=4ieo&Bt>=!ah z`TQPX@vxv&bVp8Lh1{?z*~l-7JkFhW>RRS2VHlblDs&kNJ_T8CIseT>t1;>s>as`d zAWV%xwbjBE?xf<4s1!b#eaM~in*{YpCR}6>C0kuv3vO&H%1kt6ODnB&AdU1Mzuy_D zQFa&yzsR&?UsvEx9I>*BU7hI*B3t^}Hp@2;g;%byEqgPe(lFHHE@(hBnE0stz>Bp98F2p%v7 zqNiW;x<#g2oru;s7R?u$kJsTPUM(tw6{!KRM{k;eHZudcO~A7D^>M#lt7fLd#oR|V zC<&ht4R&KTzkQGT37}NT#B2;6yxZ7mv7z@t5nzb3nbBop6AMUG9_JASzl)rm1~oR{#dc9yKu~qF|glHKzCaYUkBvP~+uB_poI&69?T|I1ocEHSnMc zPSPJ7a8cIC4!yika`Sp!%`OaBVa|?;po}3*Nj*wP-h((sq30|GY1)FB&swt#ob@e* zptARgQuppt&<=`@Jx=N<=&apYP1CpGijQ?`Fo=5GY!cD>Q#kchg~&eB-Sv3k>Mavk zl3G#_`qSQ0UC6Oku`@?9s>Q0X!>BaUCii!8PE}h=a(HN4US@beW?cYpeO^g;d}>*I zY+HOydRaFD9`nBr0HJETbnaCMGW@ zVaO->X4;9YjDjAAdJw;Jh_-8pxucKCHPctS^UmqqaL+m!p!>?zYu$O0{3rDf{?<+Q zvA-Gtf~&17LiYW!2X`BAb{2&DE5s{gFl4Y;aL`cD5cwdvNulw9$w3lY(qht5<6B?6;$n6A*CCEhyO29gw%ww|s?78yV>3uzgHx+Giwy zluhQ+6K3zkZywde3RTula&0Wx|K6}>NVi{Qehc7V=pF>_|8_n~sO6XX<_#|pFqDf@ z@mLe&{l5OiCE#(|@c48yQj-XgGK<7ECpfV$;MDf^$-LR-O;{rm+`{Zk%e|dJjE#uT zO-Kg_tA(LcNCD%b@B5{i5Z+)dY8XNON;wWG zPYG2X5??(ioNieR>aANl?~isCr;8@1SVOJWYs5t7fvq*qCABVREH567G|U!QN@f?o zvJB(%xeKctd~JTSUz6b=A_X=3J~c|N&c=7MqH-u3$r|Y7VL$LE@*aY^4-4o1B}MI5 z1h#h3h049?*C#vg5L~NMHz&3)8m0kb*TLdDjg0yZ-ScZAm9sT7EQ6*`u;5L4O(ctu zO<3%k;MeU86XD~~XVR1|KScpv;a~OL*p;*Kv;K{bjf;zi#fQ;ZDK|4W*`E`UZ$ z?8(nGcF9{bG*pulQxnsZ3X~)Tt#%x>}yWc|!uwq~&eZThKS$Ml_ zYN*;tNl8fiNkjnSu!x9gplBSBkQ)ak3>pBc{{`X~N4QVkD~O~{!9|576B-f}5*#8P z6f7Ac%8snGGdg-QTAEgR^7jNwTh3bB+L{j`6$vvHi7+QE&RE7YT1f+QIimdHMyU)4pP!|vyA>eU$0Gu1t3s_j?FR!;Ao!t-- z5R^wocIT&NrWj|N8e1E?L&P=y1+AGN%+E0|#y;+K14N%-CL*w2pNx=zH#zr2Z17d2 zKa_nMo8hKEkg={=B8U*uI2_O`=*il%(wI*~rjO0N2@(Is4{v zk=il%{2Oo^1K$s`C#c#E%qvOPe&KQ~;(1+7=G+>zmtr@ZDqgwg@OnKc1ju@~R#%lB zoRPJ)P=7HyUwwOc;Upo)V)hk=+8R7Sp+6|%wwWTc zx_$$Bo;_j(J1O 
z^GMFQ0?|_qqnSg(xfDbuaY8-jPy(_l7uUDzi(~pNzgGyLCRNZq0>rQM+yq@8+>#a} zm?1baKWhiU|ML=JT7w2T)+Tvq;Vysj(Xssa>3#BiSA$O_-yd>p-yY>(N8Gts^uYBspYY~jqch6rxsUD+_ zAc*opWg4Dczk6|0mr~yebnUaQrjL$zy`c<65bP*|Nel*CR?yruuLKW4Sjk4zhq_mU z!~E&cI(Tz<88F;pHt2U2YJ`ds4mD+9552qfNBzabv^>GDUrD$bD^8BUIdWzN z0}2i*&DU{mWkF|9No!cn7y&y+a4vJ|#P15+M%d?OA*cx&1c1gmF!?+)NL z1^`J}D7K%Rx469a0|c2UT;4e?zi}nli-&q~>k%^0L-`0^rLbrmQ6Gv<*>Max&9ssp z;jff#2COFUQaaXE1;JHKOfkGdJnq$dLk@p&0Rs`O!`8Maq(1ED_h=XK3R@2_MuDDV zShK1MV&i;9%Sy@&93%h0dD8{L?Yp1`^*lw6tScFSd1-SQvaksP_B!5Kw1r2K{Ob9d zmp~g5equ}ILf@}*bAZkD=NnXiP=N)O==}a;5OQF3E}~3t|2t}0k~Z~BX-1Y#GwK-r zJsVOT4VlO7ynaEBQvjHpS{M4I@EGbd|7cI)+YO_MqjP6hqBS(dPeIhi%J8>w#Q?X7 zx?n81*}m#yBS{*1*c`;|SnvC=dK!}f4-rpqR4|y%$=oX8KfTl;erW>B7@JC%EXn(J z_4=KBb;h>6*BR%d&`ZhEgj$;ALZP(!7FmNTp z0hwx9{-(<(28w`6h|1ABkZmcw_ap|#57Uo*;9F@}&U0gJuLdjl+`McIl}BJqZCi2S4?_RfNfrQD-LU>{{9 zni0C%;E|rOHH2z~mh^*=eDOP$b&P>rK+rjeil)=Z@sY{M5oKS6P#kMg{PN#`l{pZ< z$!8P1NRN>q=!JkeU8ZHXSdS7(^qGg9cU{Rx4hi|b`$*;nOS3UXW>QF$#xJiYnyFj{ zbWSgi-pXsw)9wwgL3KgT%DAg&hyDU(MSw#=^m;AN%>Dz9ta$wb(ORL76Ao~E)nHEE zY$ur!U}Kzq+*m!>QQeX}d)6T^@^PSO5}FQw9+7`ss%zodeE0;?(#RzFC&|LP}nm0{>H{_x&F>RGi$YODOdo% zj>{yTeJ;t3IbO{L_KjT-+QJLto80m}4d83q2IK=%g0Yu=I*Gi65ysk_5k0W_mAVFwuPA_CERh{+wdNVO2Ys40x_>X_3aiytBm<) z;BW}hp-G`hfNui^sk8I^#eXzsT5C)JcfuS3Fmzo9>Z)hqrXfCOkv?|(nV%=cCBR0( z_lDiZln(o0e~oP9bM0IC=N<-oD;Fg+!;mn=9Vq_X>!Z6O&VYE}u_s2utWa;;<5SP} zTQJjJI)$JZ)ndy-UFl`l9l>&}vVXTA>~Q!j?{HeA4My?6-mkhpkWWkIy2Q(|!O*D1 z;JN%6YA;=Hb;Z=lE1Q*3+DL|xS|$|?It z2V4v^&Q}|&1r(n7Ida~uE!h*Tew#X?;Wf=0X5E4w^}io~r5kAg_yphey8Mh+H(wgh zc4{YETE7&@xb`>JbKLWhe&AUGofAGA$&jw;IzP82eLrZ%e5>&SqmMd2EVXZ$h5Rh= z)!jWeCVg+6Ja?i-Pv?`K|FOGoEU4?ZZTDb!`sYWp$LK}qq0p%9)Y%f>dHq+|%^2!#&o zzA;489!JJIAe+XIc^N`eF1XKWvdL#WD&gVC#@3_f-Bc013%gex*Z|;Ng0`1}0pX%b z{BfI9ec?>K2;|Nb4Cy&nSHh!G5}$Vk#}OI$yK&0p$QvKv04IdKpE002!vD`yA?7s~ z`>2Cu#nE+EY^D;oeB|lyGYQZ!tyzNkQ=dvxM2*G)z)kCqKM{s2e{?i(_X`6!nK^5? 
zHfddPZTpDl*P<$Y2o+hxJ?^=r!mi;oaf15PRo)m<@vqT45ZLztRM&q=(^8~Hr?#T) zl7Ic5Km%zq+gczkqkZrtqAb6J<_T*r1`BM1jr@Zc&)Gm;lND(u8YKL0`cwCX7pBd2 zK;6wYft3d+2{X_1gMO*($zgV1l(S|mJV9QrCNmHBPv#E~!~C!DVn#jc48t7H*O+RK z1fS-nJy^&n8YKm!*c++lR{YU6EyvDKlZsuHk7OT}W;OHeRET@q7CX26o^2*L?DH8E zP&`{989>N1j%JceJ=`ZFp%ZRbOC5szzP^H0RYcP}cB_|DMU`H6K^l*(nzogA#)-W#+kn(-~21z$lWK& zKHi*o)U6ocAgpuAjVoyb)gzc0Eh)O%C?IH(QuwxSX8jrvt%+HIk#@HH&_Zha+?*9$ z$pgFkw0aOxq_{a#MjKD6jyHTQDXx@T21;_1>?3#t7AArFgbOvH}> z>~C9Q;9h{1t{UEtWWARF2Vl|YECJ%`KckehK9&_j0o6JJX1Qo)KHD$IGMd{A+z4)U znS1-}A8|*Oox9WKLY9@gZua%d|y!@u1n{aKYyrUaH4{2*U=8jm+Ja61! zbb!<24vM(+_+WCY4y}vVGKxW5G_Ymm9w`@68bY+}DHT6%L1Ox)A{rM|lGwCVPVg*u zWjLZ!WTjb@685ibqQHhD9#*uA2k>mDJN^Npj+ru`-PbEm$v=pfUGI&o?+1JVqoZ3- z*ar|9L6Gc`X9_}0XEOK8$#UGMaA;MJ4`B_9lt*(6=&(O7b#ktm?{EECmNADqjRa^Lh<4PnV?X+Vic&5-w7K|;PN zQBvCtH8-`A^rVmiBr;BBG2gt(`brMOjB>L8tzrvomdzyD*USFw*PGr~P9U9HYmTRB z{popmZj%z2x}NOG;+P;9zpB;sE1)y|!{v)bPs=GkQ#i`hnKPA3Z%R^_If3Pa>-OTFCdfUFGA&z4f8JKf&0yu+*pV=j9Vo_dN z6)}ugV;c8br*lAV#%F4M)%TQuA#=QiAMt0(p=)i@t&8g35no=>c+WMzIvmi8>ckL~ zCk6lC5=ut{tqwt0l!I5)10mQC?#$t1M8VI#4DutKBd3i|f`}Yz1j5R$DxkS(-x_gi z@(13|@U*g~a*~c}JSJgT4_g|_QHjIJiXNOx(C!F;*a4^*v=Rrj(vx%Nn6nzmESx@k zZ+Zk@Y$u$WNGri&14|Om69yF5(lZZ^P&wUpI+7V0dTBaV*~%PJU!N--gNmeRxT3;c zXLX4F7mG}B?)7X00BiTSMwQ?d z(T~whO^;1L4p{Qi;bAhsvlk5&4F!ucbp{Zo&Mm7_`A0=m4=^-F!yiTbcc5hg#1`F} z@0o_3S?J-n68FE$sO&8h_*Ua<`O@0C5>nEw&vr{^7g|Zzd6U@|qFb0S6ct&5`w?aN z^aSA2aHRqaeWK@Tn;@=&7E{X>L5LeJ*vj+q%hAfvE=b5?GN4QazT`-l`Bv zsJsE%}Hr7VU;Zl8A%-o8D>F+jWn|8 z#+_`nm!*x{vaWM*GD%#@s%R8|f)%8HM&98S-LBsszX(O-Gs+)_r3 zVqX5V_}$zsqNP<-HO`bIJhd+`|04 zy+iBGXjpF?5u=}-n;jyjof{@2Cnh7N8KW7cqM4*3r6i)HS8d1d&Em`S%OjASKzxo% zDb9?5$4$-#%Z^ya3|(JOk$4c!zlA z@VZ;-+i(bP7HaZn5h62XDvi@qjnj0EbPeG2wbfP3dMAbh?ba-}Hlm&#&D*9-a?b0G zZl^C@lcj<35=G0ioRu!m*YZ>Og^p%7oj*F8I-6bmT|j&g5I-k6AWAwwYE)pf2&~Kp z2aAhn+UdX%g0HV0FM)w}LjGewpdEY^^4D*dI4mo?DZC3mAlt9hq?j>9w$a)LXCBz# z1qaq9h4?lU3!YUMCHIO>{;B_>KuHoaS~JAJkq|p=+M=}8kNU^l-G0i7fgNxIkZ_hV 
zox$AAVF$A?*1b+t)fIQ`Ubm1PC2rn06=0b;5G?}mmzA;Z%14klY_n)n6F%i##DN@= z+a~8!QQ|VDTPqFhImr0&o!mE^O*YaAQ~l1Y3Q*&8CQu#~nwc^nXUsV}0eBHu= z>+GcU6u=;`-_kQ6w9-P?)g@NfcknLu{@~qI1LT@Q8Nsv#0Rx1haegZ9&^qbB3hTr+ zEpIO+=O`!R-!Dh0ou%5`{?VPi(w+R$xV`6Y40r^eFC0()Z(SIY1Co6|TA5f%iC9i7 z#q@K9aDd@Qev*=lG_X!joaDi4&4x*?vZ&SuRs9b0>3G?^v%quZ*&X zqMV=%9Q0RTUk@DC1w?iL*&RTFM>$QoWx3^<JqMe>3yvzYWNO(9nP0h5^(%X=?GAPe7!%yhT zsq2@Vxq+D(My9WC)~J9Qt!Q^Bo=Ai$lqof?o~jlhk2vv;g0qV{_2UdIPk{E) z!&Jt`bNr`_BRu!%f9NYYZ#lVA5;6gb9YWOGC9)4x6mN)OHDnZm;-PKDK7hvB7#K;~ zFR-H&+lq*!8VcSF17!l!hhd1=xXSJS?oZs)iIhuwoh39QC+HRtOFM>3G|9ZWNpNk6 zd2Jo(R3Fpa5>ej}R?80C!~tK&2HnVX$H{ub&b(2NSl?G14d3}y|C!R&E5QrTReB7~ z0del>5AVDQ4}#jKY3sRm^~Q7ctz9>y{RbRP(Az&bIPl3RaC*5wI{85QDUm89Nr^R$ zArKj!wRxp^>Y>f_ol}`7*O?pf3wPq#FXit)jQ&Qew^BGHQ$xpo-x-t?bA~5OOemM8 zW=}e)hfu7~5s!CV!_4W1j@_H}neqgGrN06I*cP+4`+#9z))3!swDiZsW zqj9BiaN&wZGAk`$@JG@1G3b2}DelRP2>so!qDJC0MOYOe1ChYGj4B;AyPgi}(vmH& zIFcz9Q6JkiWb+j50|hkrOP|U2;}u#wZ0x4m+EMb1c~e%bL`3(CX zD@M>2Ty`yRzU1Y&VoO%h+oh!R__VP(NiDw0lJC(vC@-E2y{q-p$Pa3YZ;UejDdy-5 zv7xcE&Yy#PRB@$G4KMVngKfpQOuOMV{N^o)@jX_tQ;Mz6aLdnB@u#J2$A1*-ICunX zY|MNUZVbByYg57FjO3J)+anEbvSzebQG#%!5)sm2P>_&5pIYdsvZwm2%a~p+TsDQ| zPZfdb%M#xQs#=qY-`WH8jt=`&lGdEC%CS!i*bhxQ(@!VZ9ZS#_Sm#Q5G)mxQDx3rSRJ1u3!o_j z@AAG)&=czozEExc&_6D~XFbPaIzJLI*;i=@4-g_yS~3Di=tRbBT2BM>@LSkfFE7VJ zGImCSkY1LJaj0T|z5!yhfNBNhCi?sE~QjA|Myn?J_2nx1`10rrt|3k?o*l)Fyse@q}EFCVNw1os+ zJ;KiMpa{Nx@{D!B^L^q%3!4O8WeXY73;*h5U9gM`B|Y*~T=)t7h5Aq@H@%=tQoO-+ z5Js)w*W|MvC>o#1Sb?mCN5n!jGtYl)y5^?=vJ-{sN$zO}VXReDk2V9~QGABKx>B@P zgbdm=8PTAc7Xc6aQrMxmeXYrtuP#bWT^(q){teG~2t;n^=+IM9+1cLieeUM$79|Ci z<3ObPNgN1-fJ}C(-BITIkinbmg5fo^`+KZDVnzs~s8{6=9QG3DCK)f<2f!1iP#5WxXZSM)-81uHE|Rk?l2 zzTK{m!}6G_xtQz4mxR$=5ek?GgnV&zj(T}L7X^$0;bUbbC7{T%3iOXx zR|R>WxrBs(tGbSwTI2sb_m`);|9kFJ3dBSH}_9Q$sN8j{M?K`I@@uoa z{rGCA?UgMaFf+G$thp`;6lxhC$z(@O3TZ0jMH}Xz0W;$Ij4}4-ySq7=wCMQNL#ay* z?ev;T$1?Ul7rw2tR5Tyjv963J<()mSuIp{ivV4ITsr1=T>nY(9%RS> zrE}Yjj-WfR-2{r~pq|w(WAE_LKo+o+1BlJY3W}70slT62&Nj12C><)xNR7NBNGQE- 
zyAfUYOuej6Y~7g~DtqLmW&M(l^TRTWOUbH20>LTuPEbkz*SV8e&9Nzk_7kgW&pqnc zjdA5eoGHeu`Gd`hy2Lb?hk1 z<#E@DHe!r!wPAi?z6wL<-x$0|w+0{4X0qlgqZAP^wjSkj-l%VEOpJ*!G&L2ci~|n_ zXc8J41AS?d1<;%HhhqaNIegHG9UTE_V<3U;pZc2^Ic#O$Z2b@EHC3hz*7TrMlN37 z<&6y?LBS4SUQ6^D3kwVMJ(E7l$}<-ic4RSdwvFnw8_T1*j%60agdtoLx$)0TdAA zWu@sdhUp3*h@x#r8f-hclKEjb2teX-E~NDh`JBi)Pwthgclu@?gO# z>GLBbpUeLIILrV1l%wzSkTpXU-6b+(C+_Ect@v4G<DSg6oUTwGz4E=Bs(IB4y`J#fQGBlx-2{hzr)e zV1V;_yTSxyn9ADNi>k_|d(q^Z)y|K4ARG_dJyME;Z4$boU#g4&YhJ(8Txj*PK8t1S z$$91UGdVP)*5_tD#m2RPgGPilfF0JN8yV4iT|XRir*K>p-h>1|Q-=|_-$3quFRj(~ z9a`+KUvLNreVcaRga_sb)u;zw<>cgqg(3c!y169>&o6d8Jv?-$0iD6Wq}c>~e0-Rg z2|$+$sL_D;py!VB0%f*ZU{O?674zO2{Jo{+5nH~hwsu)%rCi>iZbfA!76!)R@^WKC zL$@LjfcYDG;e>*Q_73GoEYN-LQu6zLdwYA4nv?`QieamS-R!VNRnLq52fe&p7VQbkT|z;FRdGA;d1FcOcI>i#krBm7PpVd$?A|j4m)0{li#Uio?~5J^o>beB!B@ z`mp!|x8@K81q4oH<3Ntvl~ip(&3#qp%MsW{RQI(PP(mZh-cmu;2R)X;!>6}@Y5WZ3 ze(mn!_o<6OtMr{6J>aQgYx5X{mN4=kA>dPiNg4oT=j}ytO)B+z3uR{&G0dxrn*<*+F-<&2)e4G6Y_zgzjj- zVe*4p6dL4z08vXeI0dBpt=y9UFSE?GtCv4vws+7L;r|!}?AWM8`l$3^eY++ofReB(Vb109gt7 z6OvZ9#eo7y?h6X~7#MVIZ0PIhtplmsOeEk}(4{vSyV7I}wcafPL^AR5@%`&r+1NIx z?6wCG@2}a&oAv>lfaNq!qmVfKYx;@Iq{5!=#t~c!*~AT1$U_O=%AahNcf~WIvu~8* z>bibB%|GQ3y!ea2*f`9705CEEBKF9XN?v=M$Ts@UNvW9HC2y?IlZP%{@Bx-*3c48P zW>Q8G1#z@+0Lmk`Y z{%!;ML5^<7p=|gHG-09{3TYKB^=TOdJ$blazm{+Q+wMee8t8darDge}JO{ulY>MyJ zCg`3K>Vy0gAMx_so8{j9o>hPB?%-lvhQG`jkT|M9C{nthR5Bt!Qk(8_H~MT1mQz{b zF{$KG&yW(CJkHSAxfz+}Uj7D*U59rqj2Ow)~P&EtDfsFYnbR+fxo%@55ul7xg zi1)D_eUt*xy@32?k{_8uE^K324vD-!^aiwL!J+v11*AR`9-{2U%|wl_xnVV#a9tEO zG$WD{(3PAzDTvJ?Z1KqN(Lp29fQ5~<%YpHoaedN68b=n5P6h zm3vDVy%xuPq-uDDr22Z`Zt@78W@JPJP}2t$mA@xoy#ifH@LGV;85I>3uqKCW1A{^T zT!=;p?yaq@Ki@8Ihy7a#2nl<8dVpyFn2GR$)Uvz`F>6ruba#%8g|*)5Tn3E7*B$Nc z#l*x8kB;np z0zgrqB&Q06PR3Ivu3gm|nPZcQzpiQgY5P+1SOsL4g`vRG3I1iM@G&=qrZ7UVG=dx* z0>(u1``281Z5!&XMyX3L{n%DkCWUv;+$(`q?b}chFceD@gjSwJjXXD{QkBeLXj|e+>R{ou8a? 
zT0}D5J`=7vOb@co%v6K_TJ*f`MkHzMZ`%;8mOX&zIRh7?lDymz|BbE}pnutRzP@{g z4V98&Wtqhi?+wcODVCN4vg7OPlWmY@WBm<$eFMB-(dBgdS`%0AK!J-n-ybD{=7Q{c;*?8)SL4ICJks=F--*3@Aur#LqZ{5KC%i*c? zss!CP`6#^mCO&R!REATVm})YS)n<$W@O)7;&(bf$A~DajE;iai>#4CKoKbYYF!0$Q zk@Npau|4wnc+l6~-3?mBfq}x};)BcG5o{*CZ#xZI+S+N^bfDsBZg%_ry|S}2i$&DV z(vp^%`cERGACWNPz2~=FQ4(74nxTb3-*kBjvjgE|;DZCf0ge`6!s3QbNQmeJ!(C8b zJ{p=G9T^Ez9e8)~z8P^{LR`GZ#f^7yX{np8E#wsV9YVNQGx?FszK=%sW8j9lXE)t| zu7FuBFDuJ41dOFyS^8d$YqxS?`MKX>-yzs%bi+jSEMtvM793AO^Ygzunt3WoY^lV1 zYJoj+`0$i@%2++QpS@$vTkyADt3?F#1qFSJig$B5tw0umWPDjt%S;ry;gg{SGlx(w zp0J3#a`jOmwn_tw+~fjU1o8~bI_;m@DMfS*qOW6LS@!n@1U9nThS;xgTfRt(X;yX& z6vtq#rYK9)D#?J296;B7f?_wJm~fpBd1|G_J^UH3>s-s4Sj`#3?Q;T_wkUW6=_CMX z)s4zE5EgCg{;u}=Get#np6~0^*mibzA&xuExB{`W-$>Z)J%D=%hjg!(`f=D{*=4Io z2bfE1HdS(WM*(ia!78riUiXC$Cm~GHgh`j!Z!DAmuc-+Ot98m_Jd08NA;*k} z`49>32ssyio%vn|gloR&>S81WP zY`ngo-VFv0-}0x_&uR5vyB<&U?9tY93Wp(4@DlpM5df9l%KYn@W}xS>#^sX%y$|Z^ zRv0V&(}BOEp8V95@!;RhF)!si%`V}KhSw5*&U*(r85y9SWcKR!%IxQz(&Y1BYjd#( zA=&D5h-SRLH93jGFIISvnrG$WlK5{MY%NTQ`m_;ILc}R3#gC3HK0=DC=sAu_vhp}S zhDNnQkR`M@$y32-mRtb{7pRB<0v>?Wu4_MG4~BKJ_D|#swH|(uI1kE`kNsdDyhUbx zn>Y=4t{+)-giX89yXaDBpFum=Q8^qq#`0e4bFnRg@VDSt>H5FBf|y%*g{57+T{Obf z_J)lho2a2oq;yXaR(wk@H*w~G2C#f_ZN=1^boQ%w9Pio9FY-i?k>&0C@cSK;S*$B@ z>`O6|AgCP&VVT*J3>4khyq|S-tV0_?S(}TH^aF0Cs`FVMK!e#|Ju(H-;jR7ae#h)z z6wfM)`S1S&P^Rw?Hbru;o7<%)XC%nbOLHhwSFC7)h?gq5cyfWF@~+ze<(wJs-d(Y| z$B&C$IHQNXeoanJHAW^dgdvD{rwn_)a^n3APV*Oy->A%3O%U>W^8e1_j88k>kLrJK zUXJH#k%%TV+4*Wf2uP8@F9?_;vHD(c08vZedGt1}>-0nmpg;dd{L>=7U{kCAty^N`-;Il^^r?)gX+UoK@+Cu2hrhk8g9T_lFj6K&&DbJpxzcPFNkl z#Ic%M&5H9=zCVw^LSA{Bk33RuiiGiDjLUJ@HWPP+ZheNI;zz6i@HJ1_*fgupN3<_E zD*))TtSt%Y4}uGcy`LEel^S;L{V03#<$d8kkP^5cPCuEJ#Dh8kC~*bVg9GYa==7;N zPal6GTEO0R*;8`ai+#V`!y3B}^eWyikKZ(K!e&g_utJozEkA?xc(ZWzub3U-7d|x`h(dRr@u!sU}?fI@Gc=Cf#uSG#rMJl zkC>P$Ml#~yK1*Cf2&>Q=jgDoW`y*%Sp&a=DE|2CAK<@fO?Rmy(TZy87x-pbCj9N40 zEX**?pG^PKRlhXCA%@&dL^TN&ebDR&L5AB5G-xE|7$%|~?J)G;I_?VYF!7*Qvyo48 
z7_R%EiLb{dMSGfl!4sNOdS5xyO{TH94HiSL%fE6Ien{v&EN8_VSZr2x2+&*cdVe`D zK5$!cXV=#f@oGvyv(C-)wevneSF_`)-Dm-pYt?_kFRRox&t4v*nKMgDd#EM;`>8bp z+g#AQ9)&B@!|kw$LRP$qgRHz<1VhjlV!c=5WHVAYENT&GKf-w7h0s)z%DCCIn+%~i z1MoU(D43g<)-~;*Ca~JpwKDYCM0moTt{c^j{di~+CDv>+1ek8C?P=i&8&v4f1iMDA zhTqM)Z%QsRvWULq{bu%FxHDpYWnpRBwp@%%j^){7P0MUgy18X#R+E;7zmk2K3OL`d zkSw=V0A6!WPV86nP>cXc0t)j8AY0Ni5%e0Lx}vvKvIp`YAd5|jEO((qM!#E7RGK7& zQqne^>>WQ{iR2DL5|6OUJ>4y$c1pgl?#bZrszZgD7A3+0;VHn~(i1GQsodYMV#6-W1|mzKWv)x^hTkb}#O z#@gRZyykX7n0B&Ky5crHxtHV-=O{AdlO|)47{kkT&y%hqIA&#Ycqrm3$|W&kp=G8WkXr?x|`$TPEbhyjczZ^@H=O?oiEN{XanT+}HOl zB?b4=6+dAliHVAmGB(L)8`yIM8uRe+d6ND~eMbg1IUSeizsu~)tBwYH%6a>gL!=h8 zy~C!?MmJvj|L5wnUKv&zp|Qn$tbdmOtQq)e?~%A_<=J%cg4VqF+tkfo`eJ@86RGTH z%yfe6@6`4q&mieRhZtHdykDhlgJJ&^(DW@-|V+hH-GHjVlIoW&d`d65eI zqsQzssp-NT)<^hvI$iTDEr^~LhXmi=6Ncv2VU#Tomo^zNFztz0B@g8^~W_XRyb%-kA)l%qQqPC<&N>-=+plRpwbt7ihC-qHf>FCQ~T$b$xcA( zVGq2rf{*pkIv?l#s|>V_ zu2{q+)ca$;P|2ZIZg9Rddfy?T4ds4~{)XwG`ut8V_yO}hF$%XE_k7G^;&=5AX4HBQuX++tXcI~~4ToF~CWx3w zAML+CddIQ#LwUhimpIg>kjf~^N=-jrrswiI41A5Z;u#Sio1q)_*1>S2-Lgyn6sqOr zX?ZA&cv51PNqjNHhaw%$Yno61fc7K&hf-r9v$olrSg2E$|KfLMRf`lnsqhVE+iiDsBXv zhT}TLQFiDwB8@A$-RMV~GHsV?fbVU$EspKCr(@uuN3{JyC7&gTiGl(*L3U^CPlnM^ z`fD;Vb}~_pUt-RFemf48ytQxuZ%ows_KHBEQ4L)Yb038}*^c*gR-wY8Iw7~^V{?G$ z&ttvHm&eS3R3zRbiF(4BRguQTM46A?tF1`f1XJql>o_V{Lmqw%oAfY0Ya?BkCkAuO z3&{%TdR;;V$|oG_BlJZCXcv&B;6Ar^OE<_lrS(9eELMf zgyG9u;nCXAT*SIPJ{9;bBo_~@j8Z%MV=zqs;icrIKjlDE5g`UI^cdwDyG%IDlOZdgvrN*Wrg{A3UWTJmZ` zN)<-s{vI4@&ct-hG$-&HpOUDXWa|ZH9j+}tbVg*?7FbThq_%~{rpJ7~>Iv$aw&PCT zuyk^-=*kn2^ql{dF9b`qWWu&7EyyE8OPEJk+%cA>UWFeq>|BT#Q ziQ{KoQ&*?1rpC#|#bsjv;9{UNJ3c!4^HxFvcj;U{P$Re&osi*O+x6X$^KOC%j@h?P z#!TPR=?!b&9d?ArpKVcq*RT36%%jd1JG)#~BibZBwXrlcs-G*jYkUY^ zGF?LC&fg+lMvqP9L9mdPQreSGwU}%gE!h!jgyuKOD9Xb~CIkn{oyx3iYJ_ZX+9h>8JXSw&FE7R&l(;_c;)KxtK25N%xH?jMoe#3iHzBzk~H zjsS_>oUNAdCDH{U0%JI|iGQLWKIWa;p(jq=iE&L-ZCw=ayU$o!xMK%{@1V*uEoqKQ zpd~O=P{uuVRM=)Lvra}o4W;?8A^(VV7Sisni>uahb3U;J|8mI&D_R>?voX{-GPy9C 
zvNyy;cksnNq}51Uc@g<^%-!ytOSizveo>Z>LymjTYwCNgGlMk?T%Se8DNPmCRl%Va zi%lI3)zM0U-)psA+1=r=AyWq2%jvnRgS_%Fdqog~tT#>hN+otgSwg%No}OaH89dM++J6ls$ zQHaRxQ|r(?)0n9sE8QAaJNk=4%XwrVP1QPrk2OxIpG}Cir6N$Sd68RhznRW9XN*ks zRD)eqPjC}YKXgux&GlznmCqpAS=fq}*X>LGSvZfA4dII>T-ZfI}0>C{2| zXSM=`_Z$&V2jAVv7TUg4ha{BzR@`0^hk_rnbKR-o9E*0NZO+;-@R6^JbhOrPCYVVl zCt+%fvnh%#I0{QLF%?lGqcy?=>nIfxuTI-nT}dhO{Z0E^y0WUHY?LB>L*^3|9nVq= zKe4mz&w4aI4WE{S1u1Fgk2(eeM}m)W9nA^7wcc&?jMFBUNgfXs>WuTp?(~fh0RiSY zdjy{#tip}qi6d&iBaPd-qh`Exyh>P(cn&3I_*#`lMc5H~!eU4^X>L)&U4uNG38~+y zh-2>;YUEvl)`ZAM3$}A@|JvjlCMUMtRG&2N8W|cn zf+S+)$2I=hmA5C>t@{WfG2_?;wJyyVpk36t6e6CKTIW|*7m>Y5v4jRp)O5YH!S2Xl zP&j`}pAiU(|3av|CS24v+lPGN=$GO5A*?j4w+d$BYkbAr!#C*BfQn{{dv?MbFL`KU zjGYKe$qS|rphT-QEu+byp{XPtavQg|FXpbQ_p3}Z{AL_bochDRh{Y{AIXJ&CjtCC1 z_jcNL&vZ>lTt>>IoZN~U*%O-Le0zP6VH`g9!z+0XLx%gExs-lFb!~BVEzXIcTdYYl z;Ry!;uCq3SCe%YZELC=8X-Q(6&`)JX)2?-YqstLE$)i775-Yv_O3ccPZhEvO6cyGb zS{H3Qvv!0m4sYEbh8s?X|528~{XYHO^n@KmzurK8hIy-s?}j^qGmQSCEi)k$P|U-_ z6EjX@+$t?AJErgZa+CG;ybsRhuKxa;3VvWy{3o>zZ}9P((ERG^a~ju7oU$#m@A>HX zc+A9$gWzCIa7YOJ$B%!h9>Sf$pO~1NhqhgJ_RW>4CtxNh*UngO(5JD+vpf?xNeHVc ztjS4F>4-@x_R;^L&E-V;(aSKh!q~02F|)`y&Qg~LF&}Lut5!E7$f2Rlse?RV{MRJ+ zdr`%{A%Pt$0aieV#*dZc_uBr6hb{Oei{>B#KKrk30czuAx7Gx|kOpZb!_BOMuCSE3 zfrZ;LX6`wx$`I+?T-Qt=?R1m_LyRk z3stj|-Pxm-Q!Z)4LziLl^Bty~S6Sx8c%N&`rOnTdq+sQ6o-GLmF^zP$ zW={v~Q*{@AiSNJP)L60#ECA|MY0w(N43o^o%XAx^%}r>W(1FnH(!3Hdh%o9PGbges zqN!liMAet2Sd2a8Neat_ivCr2tZ3`UPkQjhfFFj}lproiqYzS;zgbsM#HQKC5Mt4`C15cah*qoN5E;D0QK_F;XW9H@Q5(^qL-Vu@)2Td+jIy+qYwzIYs399p0| zR7$q)E5o0z@ej-Sh#K_q!f@1$hRT=hx_Yc@X!K#D8Wwf~HToJQWnFYLsE&i0e z$^xIua=Qx*w5to~TRA>|ATeuTvx$V?KZ(jsBfxRTYnKcu4VtItMoZ56*Q(#w$uYvKYpxah9gblE_=AXcs<{m@Kk9FE5i3`q6@XDfau+ zp>p>6n37U}`)vLd)Hw7sgr%P-H+6VZ_BD)eoUE4T=y^T`RUa)bHz}YRiiKmgU#0cC ze>tqvKqU^9Dt#@FS3EKD96n^wA>bTu7K77zAL;K+-jTbYBrx~6Hl=mRWky&U*PCa{ zhW~mj%m|larUpN{Zdmw(hp5mG$C{DGhC;yinVAnD z*|495Ji&@uHjVRTU+B#c0r3^+&pK}pL{h9&zyrV`K>FjGL4f^HQh$FRyc}U!=4022 
z78cVpf;7Ziw$4NhhtdTT(o60FDq4oNQ>URh?x%=_e&#Bj@n(tJe4P3Lwh7;{^`Wco zB?Vlf+wQ9M3SlQE3!cNDR+0?B*P-@W!<&^@2iFJY_@|XTgO@7g9`;&tvI|e0SN^N9 z$x$;-aY<>6JZ%8UA2hXXjm797FG_?#qc+y7p*SkA;UXX)h@h$V@0VEhz26K0;RiZw z^<#11OCA>X!Hb0JUneYgVz!@wq3YCr+X!-eadVT>kdo>~Z>W7>gnI_Zhu$vsC=@np zOqV2Inf}+Ge#Y$Nu2zM`jGOE|Tc6N3eqQVqG`S#(zjC^Z6HZdVD zCl?hFVbl@lb@GBQO*qokWw21Go2?dJXTbDMy1u@?d>$5f&~tG$0nhY3?Zh|TO1<`B z`@dab16x0bqV@hH6+X(A=_`hM*GenO_doKe@wDgu8a7{nE($QCdV5!=Nvf*MP{T0P zK@}kHlvH)hSVi_%W6yMNC+;T+Tf&S zhVn(A_mzD}>3j_DHUDsR`Q{*w=ahbZYJK|Wf4}?jmWC@IQT{-Sj)l`{2mZP{9mZ8t zbVn4yL$n{#m0LM5X96}h@bKZHAxQof{<_^0ou>Tr;@g|OA-^?T5(J8AHVg#VljRGo z^YvDwCPqfWy_8U)MG~w33u3qOC{JWZ>UCY$9xQQItoJi4zE3##)NyZc2h4m}P|Ou< zX5-DRt(R9bsAF^NeZj3Z-WV?C}9P#=yv`2a1Y9Ck6(9X6S9 zGZ~S{j$YR+5IBflN{OLBYxcBqs|~h?EMeF>m7i0Uf&T&o)!q4qevbC_`7p3t4=q2K zzC$tYNboa0Y{6OYa8CQbuU8p``7`P!V@LWNlMLb+JKrK-OeSaC>BFG)g`aVWww1H1 zYux|;l4lV_2i<+_Z1R#7ns*(!zOJtek$#-h89sX;)uE# z@}UP;!*HNc_7^-=n)=!5jn$MU*}gp7vY*iSX!;(=TGC>}JS6(-o-sW;6p zQgUzoDz~M>+?LIBfBmhgi4%8B1{dHJapOo&5Pd@kGiY29B{4xfl1Aq}^#sfEcB2x( zF`P#IlQwMoDH9c1GX?zO$>3uY7|4|6`DW`cDOKz^Vg=YpLhTxR=Z-K9e`)_o_7 zH=;3L4LPi1ySyu2mFYRvM{2bGEvPFDr*J>b_Ru>uKI7t!O4%7YlHI$7WZ! 
zO8Qmm;Eq#M8b5mRTn8hNYS?P^lo=zoMZjlf$(ZNdKRmU2KP7rG{ug)uCC_pn`cl`0 zQzoxUJ3n4N$##u%+`Tf?!z$dfBGk<)RNQRLvm*DrG1u4dpO5JX1<0?cMFqpQrX;+# z{F`~QgmS)aG;nwq4QslAv;KoMW>);-SrlHa>5xZpWjU)M8c)On&m~_(4~-|#G4pw@ zX!d?odKX?HDurYXXyg5?srf#mwCU@>5Kmqn=-c---2TF@#*+FY%#52N{MG(MPLq1jWF-rLZ&^%b4x5GCC*(!0``F(+<)d8#;Ar= zpK{3Im4*1)JeMfWTCt`1=&B!nxNU9d9~_j0>zt|+E9hL(_2+iq@^Ew3exp#6S1kKCIbY|GpR-za8mh<%3eg zY|SH4pQJPazg*oj+29d3F*R9r6ZOilhI)R%XwZSQ_OpwhL*u+?vahXUN?klRtf4{K z(QuDXc~$=DrG*%jNgDK1kNU2&?aF3u2d6wYvaez3Mt%B8mHpT?>ji7h`lv-FuNC%=1{nD8@kt8?fk>0F#u+>+(y_eeWJCGw z$!TCnR<^D5E3lzHm=z#vdyZjQ45%nQ$eK3GCNRP$hBOBHJ>(PDv@@)~QR$KUXk-cp znqE!E$zbI3@QuTaU+2rXP)ytw#Hc@|%b2e5;&=59P3!I6Hd2}7(uE!180$sGh*;`| z<^7-@T%B$Wp5e2V_#;VB&TV=JL&+}8qrA;XU)}UvTv6HD6`r;^(ny$1P^4KxbE$^0 z2!5h8Q3|;c2x2;wSABEwIqPYVevcLWZn!5=u~7=!oC1m?A#RLgrUI0j95Xg^GZh@+ zDgN~$mWQN11yu?!@28bG#3sn6DAahSzI;oU?AMsdYGi+miC}M&H}1DHez5DERzcuW zLE=(F;!>SbW}8;xo7Uo*Ruh_DXZC2dzjm5ASlX*`WH8>>8XRwm)a4fNIUMcVtwT?m zkk#RNVoaKlVRZlw7zMXO;lCP(;D5~vId(W7dBF>eDesfWDw~l;ei>))$T3Tm>_vzf zXfh*3v<;iMVJkW|TR#3mDVNsD2Vy~elLuY@D#U@7oV_S!3J>7|Mt!reFcTe(c z8y%DBn4hfO!O@|JQs_~h7?e3moIj1b=hk`@o~<*en*UsBZ-$JKi?+>)RFK5!O2L1^ zl{I0Ur$hr=yvwu#Er_fnwU1d67Vv>GV-W6m6a31j9VS;vr5|=9WSmvnF2w`#Ep??L{yO5U8$|Lj>FjXL$*sFO7qp5teO8&SvhVz7=^2yd=U;GN*zw4 zU1r5?dF+9!CB94~m;f_7{zys%~`Ou2**n;<{bmh2s z$)aT0vUJg$I*?zxm^QVHI+>O(nR*DHW*C?H&*8|((a^|IM|4zMIj4ZuwfXRnIZBLL zQi57sf>MINWT~Vmd3*tRJUQuB1oc-*YAH!tX(<|M>9Rk6iid|whlWarM#|`@&8cZE zX=yBI=qyS~4~k0;OGRgI*x9eyIj#X3ubxCJcv(xF7ljaAmrz+imrXzyAmSXoy&a*v z;YGdSp`fHiMd?0?;-i1UN8j$Hq--ptbTLJl`D6;LDzKvFIV(euf3GZiEH1e(tvoEQ zJS_3O|J^L!?D@sho#OtH9QM8l=AjVgF(1aQ=)tq_v9;*2xyVCwRjHk%vLurc@H$I~ zm3CqhXf&maTXW@Z6HWs16c$YNGujI@#G(QdjTn>GT9=1NnS^ zO~-vp59DxrxCj8i{wR;g&l+u*Qp9%i5O8`H+dA=la$iZa%w>AK7RPzOT7kIfAFIxx zvX9WWZG1AN#6{m@k<2)it>E|RTM@(6`6$SQH=1@9e->5D#S~{1PN<>TnE9ukjL0A# zo8y-MU;o&}+3%Py(c{z=1*<_BHcxvSzAL@00(}yB9g8*FpZ|U_Bg-3$W`yC>7&)@| zkhrsmlA%rf`;P0gryGriXqfHQo~+R8eO!12LAh_;Skt)(f^pD;&+wm_ANyQAo@Zk1 zr&6TLJn=;6GH3q|8XF= 
z_OiBTE4TJ%?ZDnZVIn1*m$|j2wI4g@EY~cPtgb+Q6kNt~s}?KAFFc3Rq=H}`)$o#F z6}F(7^tQPtVN8xMViC(pFnDkk7I~HwxS>{=yN{@$Lb=IGRs1WHhDUv+WoAIr%Abb8~C!>dUqEv7@Y~taDnr!gN9SX%WsQwzx0X>*{C>3_G04(K>+OZLhNNGH>AAbns-BGaURIyJaI?$}Bymvl>qW(iOX8|+v-XE;K7-da-y3;4T4n5hv8z3Xv(iV)I3^7d0% z{$dtel7pjieIl+Q6&2Zv#-#F3aZqboMKQ%A>pY%x#)I2pZJg{T$xdPF zhl|<`-c#5Fz+nHr)F!22SH;s#wm3MWdZwP%4%lGd8=pihv}nB)&U;5l6|&7H0q9tEv>U2j*O6a!(!u&D+SM4(9j9|P3Bh5Q+QoP*m_0^ zftfasvcMEJSMOcTdGq>F6LXAn!H|TE`CH+J55==7<>0jB>Y9jsEtPXC9|f!Mag%H( zFJYp|XyCHOQ9k{kbHg003WDH9CyLYC=f6>eY0h&n=kYZbs0o(Cc02N45%s&T<)Yl^yp& zy#jg}lN#-R>q}$8QV$)B={H_@xzMx2F~MU{n(l_<_p)^I*d-CLFcs|Gl{~ zVB0ruj5NLeGzIK{@+Q!+23S+&c#f?;G^)2WCjKs_w#%xzfw=Q{VHQ_T$P#RsjYWhl zvGr79SGm`*-i2qYFd;{VaW5s@SroXWuW49y9*ko6TC0;y-z4|sU&gci$IloP?^QJQ zwpC70bJfgfO-->ZbLhD4c)p7=(*r8GoB7HzzZ#SO__)HtfTCl<_DA4Wdev2Y0zY&G zo;w;7#js+mv13WM)y6s@;r=_1GqD??sGv1y-_$cX$}eWu-s0vz3K0qdNycwcp>qCG z;mBCVcmLgPX4aP-Xw%^|MFUlkLQq{rc}*qrVsuX9MOU82*K$;3Y;Dhi<3MQ<67lIR zST;17N_Nli@RQEjgE^hOp27D0solSoe5hZb#Ag7RsHino3aXTnHx88ZxSLf=veX$2 zBWD$F$HZ@2f_)#ea*oWndR$Hy9nwv2PCTdy3iN$mAKsofh`&JJa%MV_?L`>yO3*YJ3OY_@;jiXMzCW`F zb#r~@Q{JsC=_$(XQr7;q@kFu=gx|Hwm=}%~R=HLl91ss${fw#eVI-;A;rxTB|^IdT!J4!oNg_LLZ@Oy=| zVPlCRS&bP~f!2wgJ=smCh00iXiMj0huU6T0NEI8fK4K^FIq@mBYE_s4L=1ZXMZ=J zdnpBIB*er-ybjuMQNz&A0u^mE61|*J zb(QRvf`C8XS$HiFitC#R=W==Z{oCM!@D^%+O~k`%{ie`Xn8ke91FE?Jx#@c&k#$S? 
zs*q{hXT{dDFK*n_eh?`KVSezzkHz}scLB(Nn}Q{7)VQaVIIxv1a>*Lf`2A(ATPxVD z(3XB%JdN1Pg_qv5bvRORFN1eQ{;-u0(uHF>Me^DQ5jD;L^@&w7$h*aRxvFCNX6dH- z0rHU@F;K;>n`)uQfJ5WJ;|o&!h>AYO4o#2#Rdg9~vdHyoP*LEqMi<}pH3c zV-Hg8-)sN0bVm?HUSHsk9h;DYa|#`6w(Y4-pGWgjNLtD1XN`fZPu(FFwy?ztf(L)N z52n7kOyo7MVz3$M1-o<(st54dqB;uuP5P0yds4=9%&E9u^3T?C3;X$Y(%an-M617bV7~--|G*S8u^R z{VlbrtVP8{kpG0@J&F0EPTNcv6LS|>gJUSUzGV&f0JIWdr|T zj?6S>sr{BHcd4EgCvuc_wnM_y@CSn7F_NjdTf)!#C9LaZ`K}92%C=1jl$E%1ktuO>VkHRPfo} z<}x|COPeO{7m~QWolcgEQbYdlB0%Gy_Mo%*5jR`U>Yw3K`{`LrS3-luk;!zCfOVir zBNND_57PRm_a=s;920HnMBGLcJV5N_b6;JE%mzqU9quWkZC+qP1x`tnAAPMnSpCN% zdxKg$fwOrS;ioG;wq_yCbrCZ#VlV9VDOXSBhSGy*PMnCA{1&gQ~{) zr{I(S8{Qdv@V@a@P|tkTN)NnG$TxzTZTz}GBF}Moxhn!zbKpbJj`uXI_=v2i*Lw8~ zkLNcgV%OE1*#l<}7c63p2!<3N;p0+%{A7{;*a^v&;#|0{taL5x8Tr(sf@R+)*cezg z0^wXs7nW#{PGt3EH1H8sm6oX*9d+@p(6c235#^&syv>+b1#Wk%a*b)Es;*T`3m+DjGn%fRq%Py(Ai z=2AIYUrvv%(CSd>Q?_fveht1PC06=@4E>Qxk54ko$*D#Pdp_h$RyZ;gt9{Dmq82eRy&$#_mxe<8n&VieT z(3v&l$SN(b{LA=le6#3Z#8@}gUdVJ9i%hRm4Nmqz@PT9~!D9U(7iXf4!*qz#?EQ=p zf<*Q+>1RPAU8EA;@4l4--i@we-iF`&cxU8W=C_pf3|!B>{%-ZWhV-qNt}ji)*&LZ% zmPJtq|5t#ySIi~-lgo6a6DF;q0kvz+uP5vPoHU){yJp9E-vcQf*&NX&#ow?=0faW+ zxOp^uo^I5~1L;9UCC=LuSs-^Fh342BYDA;Jq{9B@F%;z|=}ZZ#Bz_23!ii+B6tBv8 zsN1{#F}ug^90Z{-t12gNhe*NCNQn{C>uq0C^0;DIc3%!@xeWc=`$ocXuG(ghpH639 zCo{cI_04u_HkN#*KL(ld1Dr9n~{kO789q~49-_kRDrKi~YE zb7s#u`|PvVTF-jcv&Nco0?TW$@=aQG7`NbxFgzuSc6Xo6YBRn|lE z;;%4E>!{2Itw_iewcue~GOa$jj6A0W4x6^t+`Yc9hw|E^SAet@pFpqP`Vll`n@KI+d#0l-B-pg3yog0_M62OqM9U|)uS?|hwyBJY_ir}oY9 z06tG(SXuvN-%dAel=`6mgp3{9(lo!TqB!k7|n$W%h(6I&?XGC}_|>@rzXr(9(BLsdKB*Kgo0YVew|lT!g%X6p>{aLC!Q$Y6|E!jg@Eg9Voq zB(6d*<2kI!K;DsA??@KE;npKon!m#`Qi1FJST&CDA*zZV_{_{dzKBUbN?%pg(wL*B z89!JxxJlGFg%tHxu?AW(n_oNoa+!3-pIzG<97p;_AQNNj6Ut*7X2MO~cn(ZeUD>U< zcCgcB&25@EU;kBKLJgu-c4tY|3Jml7m)Ff_8@qcGQ*Fc#g*}ZqmY+t|f~<~4!cF`RH{qBJelG{S~UI=g8-GQ1t# z50BYCvn?aiZZ2=6>)M1f(r^ejnvy#q%C)NeAALhn-37n~SNCMUYAdMQzqvRQ!g2>$ zk5(|TE!H*s7FjjA(A!?NjJ7_TI4xImaCMqLG2*lQG);}{~gu6W+Mo;L|v91snP2Ya*ePAolTLF-|r!f*p*l4C2F(PVk 
zw~^Bcw1kD6?hcLerLI1cqO4qP6M0R%?ruA~hRG4#7WJFA`<$K^ABOBGyJc)-o#%-q zuJ#u9=Q0zV{SEKz{#W8bdAUr8!SzVr=) z=h7*9rBP08*xB;x4&{&q|D1Scro-2R&gd%B(q0VdueL^4-?JDXZCX^YcJKdn9#aCV zWVlcTYe)CU-w-S#qzw|)I-HyPYkn>ca=gimn2I3FK!k;5gaS#_!ARNFGE6qunI6{C zi;uuTGhEDw?u_`QXShWxMI*$-qiGNhJDW-8-u@Pq*XMYLVp}uaFCmAk+1!SutGhM7 z;0_C?tJ2GiT8x{#bM946!U$k0cp)ZtPPU=Br4OE&`M(ZRKCLGfPNQ9w zAuA6_qv%$3;phc8Ao(y2_2ihrl0`0$ttt2YsU-A#VIf<)FTgscW`?Vg<+PPe0N$f2 z6;?vMS1ON{8cW?rYjv%Z1#^zcfs|dhGE-^gJ0{-@K@~&4`jJY^3whHbcs$+Pu+JU6 z;Dg~8M+tNFl&|I( zPS#52H2H|x_2xm%bf_f05)@MK69l{ufHk}wwhF2fg{i#_F`oRu(@4f1-J-XMJ6vmt zyN_7A5>q0c;NipWVa)sC2+G#+O*VQOtbO4H!fNu8stKM)*ZAF>H};8H7=fK*;I6~* zS<6zL<6&VRd*u_1zfZhf67;KqufJ-TnRISM*T`H(7onZ_L zehG-R(#%!EgsQ%+$n0?AyX}MBS|{r|M0%Kq4MEGGrm;A1CMe!zxU(0O45=*&y}BLW zg9o?M^@F%Ly89%P>kOM&nuJzLiT1CD*3~CO1|(j$`gZCo{pM(WEZ9*-t#k6OqD}Fj zytevmP;!k)`?a<-65N(GDo3_M&wZr`=HZd+c&0ZNE#D@B{9^j2y%ppwe4a>@5&(BW zD2neEkYIkK@1>~Ue^R($`#OimJ{D0kMqye8@2$ZqH2s>hfr^((yV+7wj_W?EVBzR8 zSvCFZ((OO5E=pmJaPrKDcq;0QVMMEPfX+xFbA^5uYSyA}7qQi)3y49vUQWjPhFQsIhKmCvk^E0)-R}@QDIr&&xS# zm@&=t*`Uoz5}Juaz|fVZ*K|UzMrS}eicV#v22r)?n;4BY(5uqnE$ZQ?Kn!!>%3(;PI0VbrG*cx`Y&jo5 zm^;ETS(on<*}M)|&R3bZ7dH|e?wZoCC)k4yp)D#@+8JXutY)!NgYV7kx!OwJXzRy) zNGOpnj3L+xyp0w^39ekmv(t>2o!M(L@qXO>iJ9&cLPZ=qmA4)=|46UA>oS5r^gRlB zwSAQ{dVhk!U6mR+>v!Zzf!UAV3E**8^R@VIFgm(?0Nj#h-yEC@#Ht>pQ6z5`=1xo;U<5?fQ(LzRxtAs z2oRV6%N}eyUhdPUYg=G|2W%ad(a#Vt>yd3V5 z19S;iCYdg&*Bj1|a9UI|W*eXGHwt0@zy#-)hG=Mh{JzcdW7;G)BVb|`ay^L!OstEM zX8F!paD6dil9DdRItD(rjyBG0Xm6r~cM7MhSr^!aPbJ7tUESFsfx(io?98~oF;c*t zjWYG!Hn*&Y3D01PMB6YzZfv&$uBywnIXc2qe?7p$F@w<2L zE{Ex$l^)&PM1aGEg@u`!nJ6IG73dS3mBj+I+`yT?pWmYNV#IBdtU)WF{O3%xeUFWu zo%QZrP_5DN=CAXK-F64eARs0tCMY-%!;;wu)WwdE9~?RmonyzR8{mr0BWyrjO1N~F z@m_jgEK)q#QKpL}A{$#lcmXzFH(USzIBoe1H;6t(NY%?Q`3m?vc?^~REN2w>AGR5< A1poj5 literal 0 HcmV?d00001 diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md index c24faac..00f702a 100644 --- a/docs/source/user_guide/feature_guide/index.md +++ 
b/docs/source/user_guide/feature_guide/index.md @@ -10,4 +10,5 @@ quantization sleep_mode structured_output lora +eplb_swift_balancer ::: diff --git a/docs/source/user_guide/feature_guide/quantization.md b/docs/source/user_guide/feature_guide/quantization.md index 9e5f56c..5300ad5 100644 --- a/docs/source/user_guide/feature_guide/quantization.md +++ b/docs/source/user_guide/feature_guide/quantization.md @@ -108,18 +108,19 @@ Please convert DeepSeek series models using `br_release_MindStudio_8.1.RC2_TR5_2 ### 3. When converting deepseek series models with modelslim, what should you pay attention? -When using the weight generated by modelslim with the `--dynamic` parameter, if torchair graph mode is enabled, please modify the configuration file in the CANN package to prevent incorrect inference results. +When the mla portion of the weights used `W8A8_DYNAMIC` quantization, if torchair graph mode is enabled, please modify the configuration file in the CANN package to prevent incorrect inference results. The operation steps are as follows: 1. Search in the CANN package directory used, for example: find /usr/local/Ascend/ -name fusion_config.json -2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows: +2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` and `"MultiAddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows: ```bash { "Switch":{ "GraphFusion":{ "AddRmsNormDynamicQuantFusionPass":"off", + "MultiAddRmsNormDynamicQuantFusionPass":"off", ``` diff --git a/docs/source/user_guide/release_notes.md b/docs/source/user_guide/release_notes.md index 75faf29..3a0bcca 100644 --- a/docs/source/user_guide/release_notes.md +++ b/docs/source/user_guide/release_notes.md @@ -1,5 +1,70 @@ # Release note +## v0.11.0rc0 - 2025.09.30 + +This is the special release candidate of v0.11.0 for vLLM Ascend. 
Please follow the [official doc](https://vllm-ascend.readthedocs.io/en/) to get started. + +### Highlights + +- DeepSeek V3.2 is supported now. [#3270](https://github.com/vllm-project/vllm-ascend/pull/3270) +- Qwen3-vl is supported now. [#3103](https://github.com/vllm-project/vllm-ascend/pull/3103) + +### Core + +- DeepSeek works with aclgraph now. [#2707](https://github.com/vllm-project/vllm-ascend/pull/2707) +- MTP works with aclgraph now. [#2932](https://github.com/vllm-project/vllm-ascend/pull/2932) +- EPLB is supported now. [#2956](https://github.com/vllm-project/vllm-ascend/pull/2956) +- Mooncacke store kvcache connector is supported now. [#2913](https://github.com/vllm-project/vllm-ascend/pull/2913) +- CPU offload connector is supported now. [#1659](https://github.com/vllm-project/vllm-ascend/pull/1659) + +### Other + +- Qwen3-next is stable now. [#3007](https://github.com/vllm-project/vllm-ascend/pull/3007) +- Fixed a lot of bugs introduced in v0.10.2 by Qwen3-next. [#2964](https://github.com/vllm-project/vllm-ascend/pull/2964) [#2781](https://github.com/vllm-project/vllm-ascend/pull/2781) [#3070](https://github.com/vllm-project/vllm-ascend/pull/3070) [#3113](https://github.com/vllm-project/vllm-ascend/pull/3113) +- The LoRA feature is back now. [#3044](https://github.com/vllm-project/vllm-ascend/pull/3044) +- Eagle3 spec decode method is back now. [#2949](https://github.com/vllm-project/vllm-ascend/pull/2949) + +## v0.10.2rc1 - 2025.09.16 + +This is the 1st release candidate of v0.10.2 for vLLM Ascend. Please follow the [official doc](https://vllm-ascend.readthedocs.io/en/) to get started. + +### Highlights + +- Add support for Qwen3 Next. Please note that expert parallel and MTP feature doesn't work with this release. We'll make it work enough soon. 
Follow the [official guide](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_npu_qwen3_next.html) to get start [#2917](https://github.com/vllm-project/vllm-ascend/pull/2917) +- Add quantization support for aclgraph [#2841](https://github.com/vllm-project/vllm-ascend/pull/2841) + +### Core + +- Aclgraph now works with Ray backend. [#2589](https://github.com/vllm-project/vllm-ascend/pull/2589) +- MTP now works with the token > 1. [#2708](https://github.com/vllm-project/vllm-ascend/pull/2708) +- Qwen2.5 VL now works with quantization. [#2778](https://github.com/vllm-project/vllm-ascend/pull/2778) +- Improved the performance with async scheduler enabled. [#2783](https://github.com/vllm-project/vllm-ascend/pull/2783) +- Fixed the performance regression with non MLA model when use default scheduler. [#2894](https://github.com/vllm-project/vllm-ascend/pull/2894) + +### Other +- The performance of w8a8 quantization is improved. [#2275](https://github.com/vllm-project/vllm-ascend/pull/2275) +- The performance of moe model is improved. [#2689](https://github.com/vllm-project/vllm-ascend/pull/2689) [#2842](https://github.com/vllm-project/vllm-ascend/pull/2842) +- Fixed resources limit error when apply speculative decoding and aclgraph. [#2472](https://github.com/vllm-project/vllm-ascend/pull/2472) +- Fixed the git config error in docker images. [#2746](https://github.com/vllm-project/vllm-ascend/pull/2746) +- Fixed the sliding windows attention bug with prefill. [#2758](https://github.com/vllm-project/vllm-ascend/pull/2758) +- The official doc for Prefill Decode Disaggregation with Qwen3 is added. [#2751](https://github.com/vllm-project/vllm-ascend/pull/2751) +- `VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP` env works again. [#2740](https://github.com/vllm-project/vllm-ascend/pull/2740) +- A new improvement for oproj in deepseek is added. 
Set `oproj_tensor_parallel_size` to enable this feature[#2167](https://github.com/vllm-project/vllm-ascend/pull/2167) +- Fix a bug that deepseek with torchair doesn't work as expect when `graph_batch_sizes` is set. [#2760](https://github.com/vllm-project/vllm-ascend/pull/2760) +- Avoid duplicate generation of sin_cos_cache in rope when kv_seqlen > 4k. [#2744](https://github.com/vllm-project/vllm-ascend/pull/2744) +- The performance of Qwen3 dense model is improved with flashcomm_v1. Set `VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE=1` and `VLLM_ASCEND_ENABLE_FLASHCOMM=1` to enable it. [#2779](https://github.com/vllm-project/vllm-ascend/pull/2779) +- The performance of Qwen3 dense model is improved with prefetch feature. Set `VLLM_ASCEND_ENABLE_PREFETCH_MLP=1` to enable it. [#2816](https://github.com/vllm-project/vllm-ascend/pull/2816) +- The performance of Qwen3 MoE model is improved with rope ops update. [#2571](https://github.com/vllm-project/vllm-ascend/pull/2571) +- Fix the weight load error for RLHF case. [#2756](https://github.com/vllm-project/vllm-ascend/pull/2756) +- Add warm_up_atb step to speed up the inference. [#2823](https://github.com/vllm-project/vllm-ascend/pull/2823) +- Fixed the aclgraph steam error for moe model. [#2827](https://github.com/vllm-project/vllm-ascend/pull/2827) + +### Known issue +- The server will be hang when running Prefill Decode Disaggregation with different TP size for P and D. It's fixed by [vLLM commit](https://github.com/vllm-project/vllm/pull/23917) which is not included in v0.10.2. You can pick this commit to fix the issue. +- The HBM usage of Qwen3 Next is higher than expected. It's a [known issue](https://github.com/vllm-project/vllm-ascend/issues/2884) and we're working on it. You can set `max_model_len` and `gpu_memory_utilization` to suitable value basing on your parallel config to avoid oom error. +- We notice that lora doesn't work with this release due to the refactor of kv cache. We'll fix it soon. 
[2941](https://github.com/vllm-project/vllm-ascend/issues/2941) +- Please do not enable chunked prefill with prefix cache when running with Ascend scheduler. The performance and accuracy is not good/correct. [#2943](https://github.com/vllm-project/vllm-ascend/issues/2943) + ## v0.10.1rc1 - 2025.09.04 This is the 1st release candidate of v0.10.1 for vLLM Ascend. Please follow the [official doc](https://vllm-ascend.readthedocs.io/en/) to get started. diff --git a/examples/disaggregated_prefill_v1/README.md b/examples/disaggregated_prefill_v1/README.md index eec8924..fabcf6b 100644 --- a/examples/disaggregated_prefill_v1/README.md +++ b/examples/disaggregated_prefill_v1/README.md @@ -42,7 +42,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example export OMP_PROC_BIND=false export OMP_NUM_THREADS=100 export VLLM_USE_V1=1 -export VLLM_LLMDD_RPC_PORT=5559 +export VLLM_ASCEND_LLMDD_RPC_PORT=5559 vllm serve /models/deepseek_r1_w8a8 \ --host 0.0.0.0 \ @@ -70,9 +70,7 @@ vllm serve /models/deepseek_r1_w8a8 \ "kv_port": "20001", "engine_id": "0", "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" - }' \ - --additional-config \ - '{"chunked_prefill_for_mla":true}' + }' ``` Run prefill server P2 on second node: @@ -85,7 +83,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example export OMP_PROC_BIND=false export OMP_NUM_THREADS=100 export VLLM_USE_V1=1 -export VLLM_LLMDD_RPC_PORT=5659 +export VLLM_ASCEND_LLMDD_RPC_PORT=5659 vllm serve /models/deepseek_r1_w8a8 \ --host 0.0.0.0 \ @@ -114,9 +112,7 @@ vllm serve /models/deepseek_r1_w8a8 \ "kv_port": "20001", "engine_id": "0", "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" - }' \ - --additional-config \ - '{"chunked_prefill_for_mla":true}' + }' ``` Run decode server d1 on third node: @@ -131,7 +127,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example export OMP_PROC_BIND=false 
export OMP_NUM_THREADS=100 export VLLM_USE_V1=1 -export VLLM_LLMDD_RPC_PORT=5759 +export VLLM_ASCEND_LLMDD_RPC_PORT=5759 vllm serve /models/deepseek_r1_w8a8 \ --host 0.0.0.0 \ @@ -173,7 +169,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example export OMP_PROC_BIND=false export OMP_NUM_THREADS=100 export VLLM_USE_V1=1 -export VLLM_LLMDD_RPC_PORT=5859 +export VLLM_ASCEND_LLMDD_RPC_PORT=5859 vllm serve /models/deepseek_r1_w8a8 \ --host 0.0.0.0 \ diff --git a/examples/disaggregated_prefill_v1/gen_ranktable.py b/examples/disaggregated_prefill_v1/gen_ranktable.py index 52db3ee..ad86c84 100644 --- a/examples/disaggregated_prefill_v1/gen_ranktable.py +++ b/examples/disaggregated_prefill_v1/gen_ranktable.py @@ -17,6 +17,10 @@ parser.add_argument("--decode-device-cnt", type=int, required=True, help="number of decode devices") +parser.add_argument("--local-device-ids", + type=str, + required=False, + help="local device ids") args = parser.parse_args() local_host = args.local_host prefill_device_cnt = args.prefill_device_cnt @@ -54,39 +58,49 @@ chips_per_card = get_cmd_stdout("npu-smi info -l | grep \"Chip Count\"").split( "\n")[0].split(":")[1].strip() chips_per_card = int(chips_per_card) +if args.local_device_ids: + local_device_ids = args.local_device_ids.split(',') +else: + local_device_ids = [] + for card_id in range(num_cards): + for chip_id in range(chips_per_card): + device_id = card_id * chips_per_card + chip_id + local_device_ids.append(device_id) + # generate local device list for local rank 0, and gather it to all ranks local_device_list: list[dict[str, str]] = list() if local_rank == "0": super_pod_id = "0" - for card_id in range(num_cards): - for chip_id in range(chips_per_card): - device_id = card_id * chips_per_card + chip_id - if soc_info == AscendSocVersion.A3: - device_ip = get_cmd_stdout( - f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr" - ).split(":")[1].strip() - super_device_id = get_cmd_stdout( - f"npu-smi info 
-t spod-info -i {card_id} -c {chip_id} | grep SDID" - ).split(":")[1].strip() - super_pod_id = get_cmd_stdout( - f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep \"Super Pod ID\"" - ).split(":")[1].strip() - else: - device_ip = get_cmd_stdout( - f"{hccn_tool_path} -i {device_id} -ip -g | grep ipaddr" - ).split(":")[1].strip() + for idx in range(len(local_device_ids)): + device_id = local_device_ids[idx] + chip_id = device_id % chips_per_card + card_id = device_id // chips_per_card + if soc_info == AscendSocVersion.A3: + device_ip = get_cmd_stdout( + f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr" + ).split(":")[1].strip() + super_device_id = get_cmd_stdout( + f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep SDID" + ).split(":")[1].strip() + super_pod_id = get_cmd_stdout( + f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep \"Super Pod ID\"" + ).split(":")[1].strip() + else: + device_ip = get_cmd_stdout( + f"{hccn_tool_path} -i {device_id} -ip -g | grep ipaddr" + ).split(":")[1].strip() - device_info = { - "server_id": local_host, - "device_id": str(device_id), - "device_ip": str(device_ip), - } - if soc_info == AscendSocVersion.A3: - device_info.update({ - "super_pod_id": str(super_pod_id), - "super_device_id": str(super_device_id) - }) - local_device_list.append(device_info) + device_info = { + "server_id": local_host, + "device_id": str(device_id), + "device_ip": str(device_ip), + } + if soc_info == AscendSocVersion.A3: + device_info.update({ + "super_pod_id": str(super_pod_id), + "super_device_id": str(super_device_id) + }) + local_device_list.append(device_info) dist.init_process_group(backend=dist.Backend.GLOO) global_device_list = [None] * dist.get_world_size() diff --git a/examples/disaggregated_prefill_v1/gen_ranktable.sh b/examples/disaggregated_prefill_v1/gen_ranktable.sh index e8a923a..8abe5ed 100644 --- a/examples/disaggregated_prefill_v1/gen_ranktable.sh +++ 
b/examples/disaggregated_prefill_v1/gen_ranktable.sh @@ -33,6 +33,11 @@ while [[ $# -gt 0 ]]; do DECODE_DEVICE_CNT="$1" shift ;; + --local-device-ids) + shift + LOCAL_DEVICE_IDS="$1" + shift + ;; esac done LOCAL_HOSTS=($(hostname -I)) @@ -68,6 +73,10 @@ echo "NNODES": $NNODES echo "NODE_RANK": $NODE_RANK echo "===============" +if [ -n "$LOCAL_DEVICE_IDS" ]; then + OPTIONAL_SECTION=" --local-device-ids $LOCAL_DEVICE_IDS" +fi + if [[ -n "${GEN_RANKTABLE}" || ! -e ${PWD}/ranktable.json ]]; then GLOO_SOCKET_IFNAME=$NETWORK_CARD_NAME torchrun \ --nproc_per_node 1 \ @@ -75,5 +84,5 @@ if [[ -n "${GEN_RANKTABLE}" || ! -e ${PWD}/ranktable.json ]]; then --node_rank ${NODE_RANK} \ --master_addr ${MASTER_ADDR} \ --master_port ${MASTER_PORT} \ - gen_ranktable.py --local-host $LOCAL_HOST --prefill-device-cnt $PREFILL_DEVICE_CNT --decode-device-cnt $DECODE_DEVICE_CNT + gen_ranktable.py --local-host $LOCAL_HOST --prefill-device-cnt $PREFILL_DEVICE_CNT --decode-device-cnt $DECODE_DEVICE_CNT $OPTIONAL_SECTION fi diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py index 727233e..2728931 100644 --- a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py +++ b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py @@ -363,6 +363,7 @@ async def send_request_to_service(client: httpx.AsyncClient, } req_data["stream"] = False req_data["max_tokens"] = 1 + req_data["min_tokens"] = 1 if "stream_options" in req_data: del req_data["stream_options"] headers = { diff --git a/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md b/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md new file mode 100644 index 0000000..3bf9240 --- /dev/null +++ b/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md @@ -0,0 +1,272 @@ +# Mooncacke Store Deployment Guide + +## Environmental Dependencies + 
+* Software: + * Python >= 3.9, < 3.12 + * CANN >= 8.2.rc1 + * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724 + * vLLM:main branch + * vLLM-Ascend:main branch + * Mooncake:[AscendTransport/Mooncake at pooling-async-memcpy](https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy)(Currently available branch code, continuously updated.) + Installation and Compilation Guide:https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy?tab=readme-ov-file#build-and-use-binaries + +## run mooncake master + +### 1.Configure mooncake.json + +The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path where mooncake.json is located. + +``` +{ + "local_hostname": "xx.xx.xx.xx", + "metadata_server": "P2PHANDSHAKE", + "protocol": "ascend", + "device_name": "", + "master_server_address": "xx.xx.xx.xx:50088", + "global_segment_size": 30000000000 +} +``` + +**local_hostname**: Configured as the IP address of the current master node, +**metadata_server**: Configured as **P2PHANDSHAKE**, +**protocol:** Configured for Ascend to use Mooncake's HCCL communication, +**device_name**: "" +**master_server_address**: Configured with the IP and port of the master service +**global_segment_size**: Expands the kvcache size registered by the PD node to the master + +### 2. Start mooncake_master + +Under the mooncake folder: + +``` +mooncake_master --port 50088 +``` + +## Pooling and Prefill Decode Disaggregate Scenario + +### 1.Run `prefill` Node and `decode` Node + +Using MultiConnector to simultaneously utilize both p2p connectors and pooled connectors. P2P performs kv_transfer, while pooling creates a larger prefix-cache. 
+ +`prefill` Node: + +``` +bash multi_producer.sh +``` + +The content of the multi_producer.sh script: + +``` +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH +export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm +export MOONCAKE_CONFIG_PATH="/xxxxxx/mooncake.json" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 +export ASCEND_TRANSPORT_PRINT=1 +export ACL_OP_INIT_MODE=1 +# The upper boundary environment variable for memory swap logging is set to mooncake, where 1 indicates enabled and 0 indicates disabled. +export ASCEND_AGGREGATE_ENABLE=1 +# The upper-level environment variable is the switch for enabling the mooncake aggregation function, where 1 means on and 0 means off. + +python3 -m vllm.entrypoints.openai.api_server \ + --model /xxxxx/Qwen2.5-7B-Instruct \ + --port 8100 \ + --trust-remote-code \ + --enforce-eager \ + --no_enable_prefix_caching \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 10000 \ + --block-size 128 \ + --max-num-batched-tokens 4096 \ + --kv-transfer-config \ + '{ + "kv_connector": "MultiConnector", + "kv_role": "kv_producer", + "kv_connector_extra_config": { + "use_layerwise": false, + "connectors": [ + { + "kv_connector": "MooncakeConnectorV1", + "kv_role": "kv_producer", + "kv_port": "20001", + "kv_connector_extra_config": { + "prefill": { + "dp_size": 1, + "tp_size": 1 + }, + "decode": { + "dp_size": 1, + "tp_size": 1 + } + } + }, + { + "kv_connector": "MooncakeConnectorStoreV1", + "kv_role": "kv_producer", + "mooncake_rpc_port":"0" + } + ] + } +}' > p.log 2>&1 +``` + +`decode` Node: + +``` +bash multi_consumer.sh +``` + +The content of multi_consumer.sh: + +``` +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH +export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm +export MOONCAKE_CONFIG_PATH="/xxxxx/mooncake.json" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=4,5,6,7 +export ACL_OP_INIT_MODE=1 +export 
ASCEND_TRANSPORT_PRINT=1 +# The upper boundary environment variable for memory swap logging is set to mooncake, where 1 indicates enabled and 0 indicates disabled. +export ASCEND_AGGREGATE_ENABLE=1 +# The upper-level environment variable is the switch for enabling the mooncake aggregation function, where 1 means on and 0 means off. + +python3 -m vllm.entrypoints.openai.api_server \ + --model /xxxxx/Qwen2.5-7B-Instruct \ + --port 8200 \ + --trust-remote-code \ + --enforce-eager \ + --no_enable_prefix_caching \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 10000 \ + --block-size 128 \ + --max-num-batched-tokens 4096 \ + --kv-transfer-config \ + '{ + "kv_connector": "MultiConnector", + "kv_role": "kv_consumer", + "kv_connector_extra_config": { + "use_layerwise": false, + "connectors": [ + { + "kv_connector": "MooncakeConnectorV1", + "kv_role": "kv_consumer", + "kv_port": "20002", + "kv_connector_extra_config": { + "prefill": { + "dp_size": 1, + "tp_size": 1 + }, + "decode": { + "dp_size": 1, + "tp_size": 1 + } + } + }, + { + "kv_connector": "MooncakeConnectorStoreV1", + "kv_role": "kv_consumer", + "mooncake_rpc_port":"1" + } + ] + } + }' > d.log 2>&1 +``` + +### 2、Start proxy_server. + +``` +bash proxy.sh +``` + +proxy.sh content: +Change localhost to your actual IP address. + +``` +python vllm-ascend/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py \ + --host localhost\ + --prefiller-hosts localhost \ + --prefiller-ports 8100 \ + --decoder-hosts localhost\ + --decoder-ports 8200 \ +``` + +### 3. Run Inference + +Configure the localhost, port, and model weight path in the command to your own settings. + +Short question: + +``` +curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Hello. I have a question. 
The president of the United States is", "max_tokens": 200, "temperature":0.0 }' +``` + +Long question: + +``` +curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }' +``` + +## Pooling and Mixed Deployment Scenario + +### 1、Run Mixed Department Script + +The mixed script is essentially a pure pooling scenario for the P node. 
+ +``` +bash mixed_department.sh +``` + +Content of mixed_department.sh: + +``` +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH +export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm +export MOONCAKE_CONFIG_PATH="/xxxxxx/mooncake.json" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 +export ACL_OP_INIT_MODE=1 +export ASCEND_TRANSPORT_PRINT=1 +# The upper boundary environment variable for memory swap logging is set to mooncake, where 1 indicates enabled and 0 indicates disabled. +export ASCEND_AGGREGATE_ENABLE=1 +# The upper-level environment variable is the switch for enabling the mooncake aggregation function, where 1 means on and 0 means off. + +python3 -m vllm.entrypoints.openai.api_server \ + --model /xxxxx/Qwen2.5-7B-Instruct \ + --port 8100 \ + --trust-remote-code \ + --enforce-eager \ + --no_enable_prefix_caching \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 10000 \ + --block-size 128 \ + --max-num-batched-tokens 4096 \ + --kv-transfer-config \ + '{ + "kv_connector": "MooncakeConnectorStoreV1", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "use_layerwise": false, + "mooncake_rpc_port":"0" + } +}' > mix.log 2>&1 +``` + +### 2. Run Inference + +Configure the localhost, port, and model weight path in the command to your own settings. The requests sent will only go to the port where the mixed deployment script is located, and there is no need to start a separate proxy. + +Short question: + +``` +curl -s http://localhost:8100/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Hello. I have a question. 
The president of the United States is", "max_tokens": 200, "temperature":0.0 }' +``` + +Long question: + +``` +curl -s http://localhost:8100/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }' +``` \ No newline at end of file diff --git a/examples/external_online_dp/run_dp_template.sh b/examples/external_online_dp/run_dp_template.sh index 661bdfa..70f27fe 100644 --- a/examples/external_online_dp/run_dp_template.sh +++ b/examples/external_online_dp/run_dp_template.sh @@ -43,4 +43,4 @@ vllm serve model_path \ "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' \ --additional-config \ - '{"ascend_scheduler_config": {"enabled": true}, "torchair_graph_config":{"enabled":true,"enable_kv_nz":false, "enable_multistream_moe":false, "graph_batch_size":[28]}, "enable_weight_nz_layout":true}' \ No 
newline at end of file + '{"ascend_scheduler_config": {"enabled": true}, "torchair_graph_config":{"enabled":true,"enable_kv_nz":false, "graph_batch_size":[28]}, "enable_weight_nz_layout":true, "enable_multistream_moe":false}' \ No newline at end of file diff --git a/examples/offline_disaggregated_prefill_npu.py b/examples/offline_disaggregated_prefill_npu.py index f37b508..0bf69fc 100644 --- a/examples/offline_disaggregated_prefill_npu.py +++ b/examples/offline_disaggregated_prefill_npu.py @@ -79,7 +79,7 @@ def run_prefill(prefill_done, process_close): def run_decode(prefill_done): - os.environ['VLLM_LLMDD_RPC_PORT'] = '6634' + os.environ['VLLM_ASCEND_LLMDD_RPC_PORT'] = '6634' # ranktable.json needs be generated using gen_ranktable.sh # from the examples/disaggregated_prefill_v1 module in the main branch. os.environ['DISAGGREGATED_PREFILL_RANK_TABLE_PATH'] = "./ranktable.json" diff --git a/examples/offline_weight_load.py b/examples/offline_weight_load.py new file mode 100644 index 0000000..a08ed2d --- /dev/null +++ b/examples/offline_weight_load.py @@ -0,0 +1,326 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py + +# Note: This script is designed to run with e2e test, +# please be careful to modify it. 
+""" +Usage: +Single node: + Dense models: + python examples/offline_weight_load.py \ + --model="Qwen/Qwen2.5-0.5B-Instruct" \ + --tp-size=1 \ + --proc-per-node=2 + MOE models: + python examples/offline_weight_load.py \ + --model="Qwen/Qwen3-30B-A3B" \ + --tp-size=2 \ + --proc-per-node=2 \ + --enable-expert-parallel + +Multi-node: + Node 0 (assume the node has ip of 10.99.48.128): + python examples/offline_weight_load.py \ + --model="Qwen/Qwen3-30B-A3B" \ + --tp-size=2 \ + --node-size=2 \ + --node-rank=0 \ + --proc-per-node=2 \ + --enable-expert-parallel \ + --master-addr=10.99.48.128 \ + --master-port=13345 + Node 1: + python examples/offline_weight_load.py \ + --model="Qwen/Qwen3-30B-A3B" \ + --tp-size=2 \ + --node-size=2 \ + --node-rank=1 \ + --enable-expert-parallel \ + --master-addr=10.99.48.128 \ + --master-port=13345 +""" + +import argparse +import contextlib +import gc +import os +from multiprocessing import Process +from time import sleep + +import torch +from vllm import LLM, SamplingParams +from vllm.distributed.parallel_state import ( # noqa E402 + destroy_distributed_environment, destroy_model_parallel, get_tp_group) +from vllm.utils import get_open_port, GiB_bytes +from safetensors.torch import load_file + +os.environ["VLLM_USE_MODELSCOPE"] = "True" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +def patch_vllm_moe_model_weight_loader(model): + # Define MLP attribute mapping for different model types + + model = getattr(model, "model", None) or getattr(model, "language_model", None) + if model is None: + raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.") + + for layer in model.layers: + mlp_attr = "mlp" + mlp = getattr(layer, mlp_attr) + + param_dict = dict(mlp.named_parameters()) + for name, param in param_dict.items(): + if "w13_weight" in name or "w2_weight" in name: + param.weight_loader = mlp.experts.weight_loader + +def load_and_merge_safetensors(directory): + merged_dict = {} + + if not 
os.path.isdir(directory): + raise ValueError(f"directory is not exist : {directory}") + + for filename in os.listdir(directory): + if filename.endswith('.safetensors'): + file_path = os.path.join(directory, filename) + print(f"loading file: {file_path}") + + f = load_file(file_path) + merged_dict.update(f) + + return merged_dict + +def parse_args(): + + parser = argparse.ArgumentParser(description="External launcher Inference") + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen3-0.6B", + help="Model name or path", + ) + parser.add_argument("--tp-size", + type=int, + default=1, + help="Tensor parallel size") + parser.add_argument("--node-size", + type=int, + default=1, + help="Total number of nodes") + parser.add_argument("--node-rank", + type=int, + default=0, + help="Rank of the current node") + parser.add_argument("--proc-per-node", + type=int, + default=1, + help="Number of processes per node") + parser.add_argument("--master-addr", + type=str, + default="", + help="Master node IP address") + parser.add_argument("--master-port", + type=int, + default=0, + help="Master node port") + parser.add_argument("--enforce-eager", + action="store_true", + help="Enforce eager mode execution.") + parser.add_argument("--trust-remote-code", + action="store_true", + help="Trust remote code.") + parser.add_argument("--enable-expert-parallel", + action="store_true", + help="Enable expert parallel, used in MOE models.") + parser.add_argument("--enable-sleep-mode", + action="store_true", + help="Enable sleep mode for the engine.") + parser.add_argument("--temperature", + type=float, + default=0.8, + help="Float that controls the randomness of the sampling.") + parser.add_argument("--model-weight-gib", + type=float, + default=None, + help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).") + + args = parser.parse_args() + if args.enable_sleep_mode: + if args.model_weight_gib is None or args.temperature != 0: + parser.error("model-weight-gib must be 
provided, and temperature must be zero when enable-sleep-mode is set.") + if args.model_weight_gib <= 0: + parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.") + if args.model == parser.get_default("model") and args.model_weight_gib is None: + parser.error("model-weight-gib must be provided for default model when enable-sleep-mode is set.") + + return args + + +def main( + local_rank: int, + rank: int, + master_addr: str, + master_port: int, + model_weight_gib: float, + model: str = "Qwen/Qwen3-30B-A3B", + world_size: int = 4, + tensor_parallel_size: int = 2, + enable_expert_parallel: bool = False, + enforce_eager: bool = True, + trust_remote_code: bool = True, + enable_sleep_mode: bool = False, + temperature: float = 0.8, +): + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + backend="cpu:gloo,npu:hccl", + world_size=world_size, + rank=rank, + ) + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 10 + sampling_params = SamplingParams( + temperature=temperature, + top_p=0.95, + max_tokens=10, + ) + llm = LLM( + model=model, + tensor_parallel_size=tensor_parallel_size, + enable_expert_parallel=enable_expert_parallel, + enforce_eager=enforce_eager, + trust_remote_code=trust_remote_code, + distributed_executor_backend="external_launcher", + seed=0, + gpu_memory_utilization = 0.95, + enable_sleep_mode=enable_sleep_mode, + ) + model_path = model + runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model + patch_vllm_moe_model_weight_loader(runmodel) + sd = load_and_merge_safetensors(model_path) + runmodel.load_weights(sd.items()) + print('load state dict done') + tp_ranks = 
get_tp_group().ranks + print(f'TP RANKS: {tp_ranks}') + + outputs = llm.generate(prompts, sampling_params) + + if enable_sleep_mode: + if rank == 0: + free_bytes_before_sleep, total = torch.npu.mem_get_info() + llm.sleep(level=1) + if rank == 0: + free_bytes_after_sleep, total = torch.npu.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + print(f"Freed memory: {freed_bytes / 1024 ** 3:.2f} GiB") + # now the freed memory should be larger than the model weights + assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes + + llm.wake_up() + outputs_after_wakeup = llm.generate(prompts, sampling_params) + if rank == 0: + # cmp output + assert outputs[0].outputs[0].text == outputs_after_wakeup[0].outputs[0].text + print("Sleep and wake up successfully!!") + + for i, output in enumerate(outputs): + if i >= 5: + # print only 5 outputs + break + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Global rank: {rank}, Prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") + + # Give engines time to pause their processing loops before exiting. 
+ sleep(5) + del llm + cleanup_env_and_memory() + + +def cleanup_env_and_memory(): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + + +if __name__ == "__main__": + args = parse_args() + + tp_size = args.tp_size + node_size = args.node_size + proc_per_node = args.proc_per_node + node_rank = args.node_rank + + if node_size == 1: + master_addr = "127.0.0.1" + master_port = get_open_port() + else: + master_addr = args.master_addr + master_port = args.master_port + + world_size = node_size * proc_per_node + + procs = [] + for local_rank, rank in enumerate( + range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))): + proc = Process(target=main, + args=( + local_rank, + rank, + master_addr, + master_port, + args.model_weight_gib, + args.model, + world_size, + tp_size, + args.enable_expert_parallel, + args.enforce_eager, + args.trust_remote_code, + args.enable_sleep_mode, + args.temperature, + )) + + proc.start() + procs.append(proc) + exit_code = 0 + for proc in procs: + proc.join(timeout=600) + if proc.exitcode is None: + print( + f"Killing process {proc.pid} that didn't stop within 30 minutes." 
+ ) + proc.kill() + exit_code = 1 + elif proc.exitcode: + exit_code = proc.exitcode + + exit(exit_code) diff --git a/examples/run_dp_server.sh b/examples/run_dp_server.sh index 1866fb0..9725812 100644 --- a/examples/run_dp_server.sh +++ b/examples/run_dp_server.sh @@ -29,4 +29,4 @@ vllm serve Qwen/Qwen1.5-MoE-A2.7B \ --gpu-memory-utilization 0.9 \ --trust-remote-code \ --enforce-eager \ - --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "enable_multistream_moe":false, "use_cached_graph":false}}' + --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "use_cached_graph":false}}' diff --git a/requirements-dev.txt b/requirements-dev.txt index 9be7f39..13864a0 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,7 +5,7 @@ openai pytest >= 6.0 pytest-asyncio pytest-mock -lm-eval==0.4.8 +lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d types-jsonschema xgrammar zmq diff --git a/tests/e2e/common.sh b/tests/e2e/common.sh index 3c61524..bb99b38 100644 --- a/tests/e2e/common.sh +++ b/tests/e2e/common.sh @@ -14,7 +14,7 @@ _err() { _red "Error: $*" && exit 1; } CURL_TIMEOUT=1 CURL_COOLDOWN=5 -CURL_MAX_TRIES=180 +CURL_MAX_TRIES=300 function wait_url_ready() { local serve_name="$1" diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 430153a..d0f1b76 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -32,7 +32,14 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, BatchEncoding, BatchFeature) from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm import LLM, SamplingParams -from vllm.config import TaskOption, _get_and_verify_dtype + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.10.2"): + from vllm.config import TaskOption, _get_and_verify_dtype +else: + from vllm.config.model import 
TaskOption, _get_and_verify_dtype + from vllm.inputs import TextPrompt from vllm.outputs import RequestOutput from vllm.transformers_utils.utils import maybe_model_redirect diff --git a/tests/e2e/doctests/001-quickstart-test.sh b/tests/e2e/doctests/001-quickstart-test.sh index 6490908..43366ae 100755 --- a/tests/e2e/doctests/001-quickstart-test.sh +++ b/tests/e2e/doctests/001-quickstart-test.sh @@ -57,8 +57,8 @@ function quickstart_online_test() { } _info "====> Start simple_test" -simple_test +time simple_test _info "====> Start quickstart_offline_test" -quickstart_offline_test +time quickstart_offline_test _info "====> Start quickstart_online_test" -quickstart_online_test +time quickstart_online_test diff --git a/tests/e2e/doctests/002-pip-binary-installation-test.sh b/tests/e2e/doctests/002-pip-binary-installation-test.sh index a763cef..525f348 100644 --- a/tests/e2e/doctests/002-pip-binary-installation-test.sh +++ b/tests/e2e/doctests/002-pip-binary-installation-test.sh @@ -59,4 +59,4 @@ function install_binary_test() { } _info "====> Start install_binary_test" -install_binary_test +time install_binary_test diff --git a/tests/e2e/model_utils.py b/tests/e2e/model_utils.py index 1a3ea5b..e5b353e 100644 --- a/tests/e2e/model_utils.py +++ b/tests/e2e/model_utils.py @@ -19,7 +19,12 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union -from vllm.sequence import PromptLogprobs, SampleLogprobs +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.10.2"): + from vllm.sequence import PromptLogprobs, SampleLogprobs +else: + from vllm.logprobs import PromptLogprobs, SampleLogprobs TokensText = Tuple[List[int], str] diff --git a/tests/e2e/models/configs/DeepSeek-V2-Lite.yaml b/tests/e2e/models/configs/DeepSeek-V2-Lite.yaml index 7df0544..58af318 100644 --- a/tests/e2e/models/configs/DeepSeek-V2-Lite.yaml +++ b/tests/e2e/models/configs/DeepSeek-V2-Lite.yaml @@ -1,12 +1,16 @@ model_name: "deepseek-ai/DeepSeek-V2-Lite" +runner: 
"linux-aarch64-a2-2" +hardware: "Atlas A2 Series" tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.375 + value: 0.385 - name: "exact_match,flexible-extract" - value: 0.375 + value: 0.385 tensor_parallel_size: 2 +batch_size: 32 +gpu_memory_utilization: 0.7 apply_chat_template: False fewshot_as_multiturn: False trust_remote_code: True diff --git a/tests/e2e/models/configs/Qwen2.5-VL-7B-Instruct.yaml b/tests/e2e/models/configs/Qwen2.5-VL-7B-Instruct.yaml index eb7196a..3543e0c 100644 --- a/tests/e2e/models/configs/Qwen2.5-VL-7B-Instruct.yaml +++ b/tests/e2e/models/configs/Qwen2.5-VL-7B-Instruct.yaml @@ -1,4 +1,6 @@ model_name: "Qwen/Qwen2.5-VL-7B-Instruct" +runner: "linux-aarch64-a2-1" +hardware: "Atlas A2 Series" model: "vllm-vlm" tasks: - name: "mmmu_val" diff --git a/tests/e2e/models/configs/Qwen3-30B-A3B.yaml b/tests/e2e/models/configs/Qwen3-30B-A3B.yaml index be1bbb0..6b04252 100644 --- a/tests/e2e/models/configs/Qwen3-30B-A3B.yaml +++ b/tests/e2e/models/configs/Qwen3-30B-A3B.yaml @@ -1,4 +1,6 @@ model_name: "Qwen/Qwen3-30B-A3B" +runner: "linux-aarch64-a2-2" +hardware: "Atlas A2 Series" tasks: - name: "gsm8k" metrics: diff --git a/tests/e2e/models/configs/Qwen3-8B-Base.yaml b/tests/e2e/models/configs/Qwen3-8B-Base.yaml index e60cc9a..2124361 100644 --- a/tests/e2e/models/configs/Qwen3-8B-Base.yaml +++ b/tests/e2e/models/configs/Qwen3-8B-Base.yaml @@ -1,4 +1,6 @@ model_name: "Qwen/Qwen3-8B-Base" +runner: "linux-aarch64-a2-1" +hardware: "Atlas A2 Series" tasks: - name: "gsm8k" metrics: diff --git a/tests/e2e/models/configs/accuracy.txt b/tests/e2e/models/configs/accuracy.txt index e29ff1a..2184a59 100644 --- a/tests/e2e/models/configs/accuracy.txt +++ b/tests/e2e/models/configs/accuracy.txt @@ -1,3 +1,4 @@ +DeepSeek-V2-Lite.yaml Qwen3-8B-Base.yaml Qwen2.5-VL-7B-Instruct.yaml Qwen3-30B-A3B.yaml \ No newline at end of file diff --git a/tests/e2e/models/report_template.md b/tests/e2e/models/report_template.md index 8402545..81dd717 100644 
--- a/tests/e2e/models/report_template.md +++ b/tests/e2e/models/report_template.md @@ -2,16 +2,28 @@ - **vLLM Version**: vLLM: {{ vllm_version }} ([{{ vllm_commit[:7] }}](https://github.com/vllm-project/vllm/commit/{{ vllm_commit }})), **vLLM Ascend Version**: {{ vllm_ascend_version }} ([{{ vllm_ascend_commit[:7] }}](https://github.com/vllm-project/vllm-ascend/commit/{{ vllm_ascend_commit }})) - **Software Environment**: **CANN**: {{ cann_version }}, **PyTorch**: {{ torch_version }}, **torch-npu**: {{ torch_npu_version }} -- **Hardware Environment**: Atlas A2 Series +- **Hardware Environment**: {{ hardware }} - **Parallel mode**: {{ parallel_mode }} -- **Execution mode**: ACLGraph +- **Execution mode**: {{ execution_model }} **Command**: ```bash export MODEL_ARGS={{ model_args }} lm_eval --model {{ model_type }} --model_args $MODEL_ARGS --tasks {{ datasets }} \ -{% if apply_chat_template %} --apply_chat_template {{ apply_chat_template }} {% endif %} {% if fewshot_as_multiturn %} --fewshot_as_multiturn {{ fewshot_as_multiturn }} {% endif %} {% if num_fewshot is defined and num_fewshot != "N/A" %} --num_fewshot {{ num_fewshot }} {% endif %} {% if limit is defined and limit != "N/A" %} --limit {{ limit }} {% endif %} --batch_size {{ batch_size}} +{% if apply_chat_template is defined and (apply_chat_template|string|lower in ["true", "1"]) -%} + --apply_chat_template \ +{%- endif %} +{% if fewshot_as_multiturn is defined and (fewshot_as_multiturn|string|lower in ["true", "1"]) -%} + --fewshot_as_multiturn \ +{%- endif %} +{% if num_fewshot is defined and num_fewshot != "N/A" -%} + --num_fewshot {{ num_fewshot }} \ +{%- endif %} +{% if limit is defined and limit != "N/A" -%} + --limit {{ limit }} \ +{%- endif %} +--batch_size {{ batch_size }} ``` | Task | Metric | Value | Stderr | diff --git a/tests/e2e/models/test_lm_eval_correctness.py b/tests/e2e/models/test_lm_eval_correctness.py index 18768e1..eaef67d 100644 --- a/tests/e2e/models/test_lm_eval_correctness.py +++ 
b/tests/e2e/models/test_lm_eval_correctness.py @@ -69,6 +69,8 @@ def generate_report(tp_size, eval_config, report_data, report_dir, env_config): if model_args.get('enable_expert_parallel', False): parallel_mode += " + EP" + execution_model = f"{'Eager' if model_args.get('enforce_eager', False) else 'ACLGraph'}" + report_content = template.render( vllm_version=env_config.vllm_version, vllm_commit=env_config.vllm_commit, @@ -77,6 +79,7 @@ def generate_report(tp_size, eval_config, report_data, report_dir, env_config): cann_version=env_config.cann_version, torch_version=env_config.torch_version, torch_npu_version=env_config.torch_npu_version, + hardware=eval_config.get("hardware", "unknown"), model_name=eval_config["model_name"], model_args=f"'{','.join(f'{k}={v}' for k, v in model_args.items())}'", model_type=eval_config.get("model", "vllm"), @@ -84,10 +87,11 @@ def generate_report(tp_size, eval_config, report_data, report_dir, env_config): apply_chat_template=eval_config.get("apply_chat_template", True), fewshot_as_multiturn=eval_config.get("fewshot_as_multiturn", True), limit=eval_config.get("limit", "N/A"), - batch_size="auto", + batch_size=eval_config.get("batch_size", "auto"), num_fewshot=eval_config.get("num_fewshot", "N/A"), rows=report_data["rows"], - parallel_mode=parallel_mode) + parallel_mode=parallel_mode, + execution_model=execution_model) report_output = os.path.join( report_dir, f"{os.path.basename(eval_config['model_name'])}.md") @@ -110,7 +114,7 @@ def test_lm_eval_correctness_param(config_filename, tp_size, report_dir, "apply_chat_template": eval_config.get("apply_chat_template", True), "fewshot_as_multiturn": eval_config.get("fewshot_as_multiturn", True), "limit": eval_config.get("limit", None), - "batch_size": "auto", + "batch_size": eval_config.get("batch_size", "auto"), } for s in ["num_fewshot", "fewshot_as_multiturn", "apply_chat_template"]: val = eval_config.get(s, None) diff --git a/tests/e2e/multicard/test_expert_parallel.py 
b/tests/e2e/multicard/test_expert_parallel.py index e956ed6..288afdd 100644 --- a/tests/e2e/multicard/test_expert_parallel.py +++ b/tests/e2e/multicard/test_expert_parallel.py @@ -14,14 +14,24 @@ def test_e2e_ep_correctness(model_name): ] max_tokens = 5 - with VllmRunner(model_name, tensor_parallel_size=2, - enforce_eager=True) as vllm_model: + # FIXME: Really strange that chunked prefill might lead to different results, investigate further + with VllmRunner( + model_name, + tensor_parallel_size=2, + additional_config={"ascend_scheduler_config": { + "enabled": True + }}, + enforce_eager=True) as vllm_model: tp_output = vllm_model.generate_greedy(example_prompts, max_tokens) - with VllmRunner(model_name, - tensor_parallel_size=2, - enable_expert_parallel=True, - enforce_eager=True) as vllm_model: + with VllmRunner( + model_name, + tensor_parallel_size=2, + enable_expert_parallel=True, + additional_config={"ascend_scheduler_config": { + "enabled": True + }}, + enforce_eager=True) as vllm_model: ep_output = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index a90c864..f3348d8 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -23,6 +23,7 @@ Run `pytest tests/test_offline_inference.py`. 
import os from unittest.mock import patch +import pytest from modelscope import snapshot_download # type: ignore from vllm import SamplingParams @@ -30,6 +31,15 @@ from tests.e2e.conftest import VllmRunner os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" +QWEN_DENSE_MODELS = [ + "vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8" +] + +DEEPSEEK_W4A8_MODELS = [ + "vllm-ascend/DeepSeek-V3-W4A8-Pruing", + "vllm-ascend/DeepSeek-V3.1-W4A8-puring" +] + def test_models_distributed_QwQ(): example_prompts = [ @@ -61,8 +71,8 @@ def test_models_distributed_DeepSeek_multistream_moe(): additional_config={ "torchair_graph_config": { "enabled": True, - "enable_multistream_moe": True, }, + "enable_multistream_moe": True, "ascend_scheduler_config": { "enabled": True, }, @@ -104,14 +114,15 @@ def test_models_distributed_Qwen3_W4A8DYNAMIC(): vllm_model.generate_greedy(example_prompts, max_tokens) +@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS) @patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"}) -def test_models_distributed_DeepSeek_W4A8DYNAMIC(): +def test_models_distributed_DeepSeek_W4A8DYNAMIC(model): prompts = [ "Hello, my name is", ] max_tokens = 5 with VllmRunner( - snapshot_download("vllm-ascend/DeepSeek-V3-W4A8-Pruing"), + snapshot_download(model), dtype="auto", tensor_parallel_size=2, quantization="ascend", @@ -150,3 +161,46 @@ def test_sp_for_qwen3_moe() -> None: enable_expert_parallel=True, enforce_eager=True) as vllm_model: vllm_model.generate(example_prompts, sampling_params) + + +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.parametrize("model", QWEN_DENSE_MODELS) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"}) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"}) +def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download(model), + 
max_model_len=8192, + enforce_eager=enforce_eager, + dtype="auto", + tensor_parallel_size=2, + quantization="ascend", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + + +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.parametrize("model", QWEN_DENSE_MODELS) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"}) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"}) +def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight( + model, enforce_eager): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download(model), + max_model_len=8192, + enforce_eager=enforce_eager, + dtype="auto", + tensor_parallel_size=2, + quantization="ascend", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/e2e/multicard/test_prefix_caching.py b/tests/e2e/multicard/test_prefix_caching.py index 73d0d2c..e563488 100644 --- a/tests/e2e/multicard/test_prefix_caching.py +++ b/tests/e2e/multicard/test_prefix_caching.py @@ -116,20 +116,22 @@ def test_prefix_cache_with_ascend_scheduler(model: str, prefix_cache_output = vllm_model.generate_greedy( INPUT_PROMPTS, max_tokens) - with VllmRunner(model, - additional_config={ - 'ascend_scheduler_config': { - 'enabled': True, - 'enable_prefix_caching': True, - "enable_chunked_prefill": True, - }, - }, - enforce_eager=True, - max_model_len=2048, - tensor_parallel_size=2, - gpu_memory_utilization=0.7) as vllm_model: - chunk_prefill_prefix_cache_output = vllm_model.generate_greedy( - INPUT_PROMPTS, max_tokens) + # TODO: enable apc and chunked prefill with ascend scheduler will lead accuracy problem. + # Disable it now. Fix it or drop the ascend scheduler in the future. 
+ # with VllmRunner(model, + # additional_config={ + # 'ascend_scheduler_config': { + # 'enabled': True, + # 'enable_prefix_caching': True, + # "enable_chunked_prefill": True, + # }, + # }, + # enforce_eager=True, + # max_model_len=2048, + # tensor_parallel_size=2, + # gpu_memory_utilization=0.7) as vllm_model: + # chunk_prefill_prefix_cache_output = vllm_model.generate_greedy( + # INPUT_PROMPTS, max_tokens) check_outputs_equal( outputs_0_lst=vllm_output, @@ -138,9 +140,9 @@ def test_prefix_cache_with_ascend_scheduler(model: str, name_1="prefix_cache_output", ) - check_outputs_equal( - outputs_0_lst=chunk_prefill_prefix_cache_output, - outputs_1_lst=prefix_cache_output, - name_0="chunk_prefill_prefix_cache_output", - name_1="prefix_cache_output", - ) + # check_outputs_equal( + # outputs_0_lst=chunk_prefill_prefix_cache_output, + # outputs_1_lst=prefix_cache_output, + # name_0="chunk_prefill_prefix_cache_output", + # name_1="prefix_cache_output", + # ) diff --git a/tests/e2e/multicard/test_qwen3_moe.py b/tests/e2e/multicard/test_qwen3_moe.py index 13e1fa3..6e3da1f 100644 --- a/tests/e2e/multicard/test_qwen3_moe.py +++ b/tests/e2e/multicard/test_qwen3_moe.py @@ -66,7 +66,6 @@ def test_models_distributed_Qwen3_MOE_W8A8(): max_model_len=8192, tensor_parallel_size=2, quantization="ascend", - enforce_eager=True, ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py index 1eb9d2f..de84861 100644 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ b/tests/e2e/multicard/test_torchair_graph_mode.py @@ -22,6 +22,8 @@ Run `pytest tests/multicard/test_torchair_graph_mode.py`. 
import os from typing import Dict +import pytest + from tests.e2e.conftest import VllmRunner os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" @@ -153,6 +155,7 @@ def _pangu_torchair_test_fixture( print(f"Generated text: {vllm_output[i][1]!r}") +@pytest.mark.skip("skipping test_e2e_pangu_with_torchair") def test_e2e_pangu_with_torchair(): additional_config = { "torchair_graph_config": { diff --git a/tests/e2e/multicard/test_weight_loader.py b/tests/e2e/multicard/test_weight_loader.py new file mode 100644 index 0000000..f59cd1f --- /dev/null +++ b/tests/e2e/multicard/test_weight_loader.py @@ -0,0 +1,188 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Compare the outputs of vLLM with and without aclgraph. + +Run `pytest tests/multicard/test_external_launcher.py`. 
+""" + +import os +import subprocess +import sys + +import pytest +import torch_npu + +MOE_MODELS = ["Qwen/Qwen3-30B-A3B"] +MODELS = ["Qwen/Qwen3-8B"] +DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] + + +@pytest.mark.parametrize("model", MOE_MODELS) +def test_external_launcher_eager(model): + script = script = "/usr/local/python3.11.13/bin/python3.11/__w/vllm-ascend/tests/examples/test_weight_loader.py" + env = os.environ.copy() + # TODO: Change to 2 when ci machine has 4 cards + cmd = [ + sys.executable, + str(script), + "--model", + model, + "--tp-size", + "2", + "--proc-per-node", + "2", + "--trust-remote-code", + "--enforce-eager", + "--enable-expert-parallel", + "--enable-sleep-mode", + "--model-weight-gib", + "20", + ] + + print(f"Running subprocess: {' '.join(cmd)}") + proc = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=600, + ) + output = proc.stdout.decode() + + print(output) + + assert "TP RANKS: [0]" in output + assert "TP RANKS: [1]" in output + assert "Generated text:" in output + assert proc.returncode == 0 + + +@pytest.mark.parametrize("model", MOE_MODELS) +def test_external_launcher_aclgraph(model): + script = "/usr/local/python3.11.13/bin/python3.11/__w/vllm-ascend/tests/examples/test_weight_loader.py" + env = os.environ.copy() + # TODO: Change to 2 when ci machine has 4 cards + cmd = [ + sys.executable, + str(script), + "--model", + model, + "--tp-size", + "2", + "--proc-per-node", + "2", + "--trust-remote-code", + "--enable-expert-parallel", + "--enable-sleep-mode", + "--model-weight-gib", + "20", + ] + + print(f"Running subprocess: {' '.join(cmd)}") + proc = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=600, + ) + output = proc.stdout.decode() + + print(output) + + assert "TP RANKS: [0]" in output + assert "TP RANKS: [1]" in output + assert "Generated text:" in output + assert proc.returncode == 0 + + +@pytest.mark.parametrize("model", 
MODELS) +def test_external_launcher_dense(model): + script = "/usr/local/python3.11.13/bin/python3.11/__w/vllm-ascend/tests/examples/test_weight_loader.py" + env = os.environ.copy() + # TODO: Change to 2 when ci machine has 4 cards + cmd = [ + sys.executable, + str(script), + "--model", + model, + "--tp-size", + "2", + "--proc-per-node", + "2", + "--trust-remote-code", + "--enable-sleep-mode", + "--model-weight-gib", + "20", + ] + + print(f"Running subprocess: {' '.join(cmd)}") + proc = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=600, + ) + output = proc.stdout.decode() + + print(output) + + assert "TP RANKS: [0]" in output + assert "TP RANKS: [1]" in output + assert "Generated text:" in output + assert proc.returncode == 0 + + +@pytest.mark.parametrize("model", MODELS) +def test_external_launcher_dense_eager(model): + script = "/usr/local/python3.11.13/bin/python3.11/__w/vllm-ascend/tests/examples/test_weight_loader.py" + env = os.environ.copy() + # TODO: Change to 2 when ci machine has 4 cards + cmd = [ + sys.executable, + str(script), + "--model", + model, + "--tp-size", + "2", + "--proc-per-node", + "2", + "--trust-remote-code", + "--enforce-eager", + "--enable-sleep-mode", + "--model-weight-gib", + "20", + ] + + print(f"Running subprocess: {' '.join(cmd)}") + proc = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=600, + ) + output = proc.stdout.decode() + + print(output) + + assert "TP RANKS: [0]" in output + assert "TP RANKS: [1]" in output + assert "Generated text:" in output + assert proc.returncode == 0 diff --git a/tests/e2e/pd_disaggreate/run_edge_case_test.sh b/tests/e2e/pd_disaggreate/run_edge_case_test.sh index a086df0..49e06e5 100644 --- a/tests/e2e/pd_disaggreate/run_edge_case_test.sh +++ b/tests/e2e/pd_disaggreate/run_edge_case_test.sh @@ -70,7 +70,7 @@ run_tests_for_model() { # Start prefill instance PREFILL_PORT=8001 - 
BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=0 VLLM_LLMDD_RPC_PORT=5559 vllm serve $model_name \ + BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=0 VLLM_ASCEND_LLMDD_RPC_PORT=5559 vllm serve $model_name \ --port $PREFILL_PORT \ --seed 1024 \ --enforce-eager \ @@ -90,7 +90,7 @@ run_tests_for_model() { DECODE_PORT=8002 # Build the command with or without model-specific args - BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=1 VLLM_LLMDD_RPC_PORT=6000 vllm serve $model_name \ + BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=1 VLLM_ASCEND_LLMDD_RPC_PORT=6000 vllm serve $model_name \ --port $DECODE_PORT \ --seed 1024 \ --enforce-eager \ diff --git a/tests/e2e/run_doctests.sh b/tests/e2e/run_doctests.sh index 2b00b64..70e2dad 100755 --- a/tests/e2e/run_doctests.sh +++ b/tests/e2e/run_doctests.sh @@ -22,7 +22,6 @@ set -eo errexit . $(dirname "$0")/common.sh export VLLM_USE_MODELSCOPE=true -export VLLM_LOGGING_LEVEL=ERROR _info "====> Start Quickstart test" . "${SCRIPT_DIR}/doctests/001-quickstart-test.sh" diff --git a/tests/e2e/singlecard/ops/test_bgmv_expand.py b/tests/e2e/singlecard/ops/test_bgmv_expand.py index 0aca9ca..9d82ab8 100644 --- a/tests/e2e/singlecard/ops/test_bgmv_expand.py +++ b/tests/e2e/singlecard/ops/test_bgmv_expand.py @@ -33,8 +33,8 @@ def test_bgmv_expand(): y_npu = y.npu() y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128) - y_out_npu = torch.ops._C.bgmv_expand(x_npu, w_npu, indices_npu, y_npu, 0, - 128) + y_out_npu = torch.ops._C_ascend.bgmv_expand(x_npu, w_npu, indices_npu, + y_npu, 0, 128) # Compare the results. 
torch.testing.assert_close(y_out_npu.cpu(), diff --git a/tests/e2e/singlecard/ops/test_bgmv_shrink.py b/tests/e2e/singlecard/ops/test_bgmv_shrink.py index 99bb8e8..6cb8127 100644 --- a/tests/e2e/singlecard/ops/test_bgmv_shrink.py +++ b/tests/e2e/singlecard/ops/test_bgmv_shrink.py @@ -33,7 +33,7 @@ def test_bgmv_shrink(): y_npu = y.npu() y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5) - torch.ops._C.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5) + torch.ops._C_ascend.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5) # Compare the results. torch.testing.assert_close(y_npu.cpu(), diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index cf13010..c6da287 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -28,12 +28,12 @@ import torch import torch_npu from vllm.model_executor.layers.activation import SiluAndMul -from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ - TokenDispatcherWithAllGather +from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp +from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather NUM_EXPERTS = [8, 64] -EP_SIZE = [1, 4] +EP_SIZE = [1] TOP_KS = [2, 6] DEVICE = ["npu"] @@ -115,19 +115,6 @@ def test_token_dispatcher_with_all_gather( w1_local = w1 w2_local = w2 - if ep_size > 1: - local_e = e // ep_size - e_ids = torch.arange(local_e * 0, - local_e * (0 + 1), - device=device, - dtype=torch.int32) - expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32) - expert_map[e_ids] = torch.arange(local_e, - device=device, - dtype=torch.int32) - w1_local = w1[e_ids] - w2_local = w2[e_ids] - score = torch.softmax(score, dim=-1, dtype=dtype) topk_weights, topk_ids = torch.topk(score, topk) topk_ids = topk_ids.to(torch.int32) @@ -179,6 +166,87 @@ def 
test_token_dispatcher_with_all_gather( torch.npu.reset_peak_memory_stats() +@pytest.mark.parametrize("m", [1, 33, 64]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("ep_size", EP_SIZE) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", DEVICE) +def test_token_dispatcher_with_all_gather_quant( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + dtype: torch.dtype, + device: str, +): + context_mock = MagicMock() + context_mock.fused_moe_state = 0 + with patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context", + return_value=context_mock): + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8) + w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype) + w2 = torch.randn((e, n, k), device=device, dtype=torch.int8) + w2_scale = torch.empty((e, k), device=device, dtype=dtype) + + score = torch.randn((m, e), device=device, dtype=dtype) + expert_map = None + local_e = e + + score = torch.softmax(score, dim=-1, dtype=dtype) + topk_weights, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.to(torch.int32) + row_idx = (torch.arange( + 0, + m * topk, + device=device, + dtype=torch.int32, + ).view(topk, -1).permute(1, 0).contiguous()) + + dispatcher_kwargs = { + "num_experts": e, + "top_k": topk, + "num_local_experts": local_e, + } + dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs) + + apply_router_weight_on_input = False + dispatch_output = dispatcher.token_dispatch( + hidden_states=a, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + with_quant=True) + + sorted_hidden_states = dispatch_output["hidden_states"] + group_list = dispatch_output["group_list"] + 
group_list_type = dispatch_output.get("group_list_type", 1) + dynamic_scale = dispatch_output["dynamic_scale"] + + expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + group_list_type=group_list_type, + dynamic_scale=dynamic_scale, + with_quant=True) + combined_output = dispatcher.token_combine(hidden_states=expert_output, + bias=None) + assert combined_output.shape == (m, k) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + + @pytest.mark.parametrize("m", [1, 33, 64]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -222,7 +290,7 @@ def test_select_experts( dtype=torch.int32) custom_routing_function.return_value = (mock_weights, mock_ids) - with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk" + with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk" ) as mock_native_grouped_topk: mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like( x) diff --git a/tests/e2e/singlecard/ops/test_moe_comm.py b/tests/e2e/singlecard/ops/test_moe_comm.py deleted file mode 100644 index b034ed4..0000000 --- a/tests/e2e/singlecard/ops/test_moe_comm.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# This file is a part of the vllm-ascend project. - -import gc -from types import SimpleNamespace - -import pytest -import torch - -from vllm.model_executor.layers.fused_moe.config import ( # isort: skip - FusedMoEConfig, FusedMoEParallelConfig) - -from vllm_ascend.distributed.moe_comm_method import ( # isort: skip - AllGatherCommImpl, NativeAllGatherCommImpl) - - -@pytest.mark.parametrize("num_tokens", [16, 128]) -@pytest.mark.parametrize("hidden_size", [64, 128]) -@pytest.mark.parametrize("global_num_experts", [8, 16]) -@pytest.mark.parametrize("num_local_experts", [4, 8]) -@pytest.mark.parametrize("top_k_num", [2, 4]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("ep_rank", [0, 1]) -@pytest.mark.parametrize("apply_a8_quantization", [False]) -def test_all_gather_comm_impl( - num_tokens, - hidden_size, - global_num_experts, - num_local_experts, - top_k_num, - dtype, - ep_rank, - apply_a8_quantization, - mocker, -): - """ - Tests the AllGatherCommImpl against the NativeAllGatherCommImpl. - - This test compares the outputs of the NPU-optimized AllGatherCommImpl - with a native PyTorch implementation (NativeAllGatherCommImpl) to ensure - correctness across various configurations. 
- """ - if top_k_num > global_num_experts: - pytest.skip("top_k_num cannot be greater than global_num_experts") - if num_local_experts > global_num_experts: - pytest.skip( - "num_local_experts cannot be greater than global_num_experts") - - device = torch.device("npu") - - # mock get_tensor_model_parallel_rank to return ep_rank - mocker.patch( - "vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank", - return_value=ep_rank, - ) - - # make moe config - parallel_config = SimpleNamespace( - enable_expert_parallel=num_local_experts < global_num_experts) - moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( - tp_size_=max(2, global_num_experts // num_local_experts), - dp_size_=1, - vllm_parallel_config=parallel_config, - ) - - moe_config = FusedMoEConfig( - num_experts=global_num_experts, - experts_per_token=top_k_num, - hidden_dim=hidden_size, - num_local_experts=num_local_experts, - moe_parallel_config=moe_parallel_config, - in_dtype=dtype, - quant_config=None, # No quantization in this test - max_num_tokens=num_tokens, - ) - - # Instantiate implementations - native_impl = NativeAllGatherCommImpl(moe_config) - - all_gather_impl = AllGatherCommImpl(moe_config) - - # --- Input Data --- - hidden_states = torch.randn(num_tokens, - hidden_size, - device=device, - dtype=dtype) - topk_ids = torch.randint(0, - global_num_experts, (num_tokens, top_k_num), - device=device, - dtype=torch.int32) - topk_weights = torch.rand(num_tokens, top_k_num, device=device).to(dtype) - topk_weights = torch.nn.functional.softmax(topk_weights, dim=1) - - num_experts = global_num_experts - - expert_map = None - if num_local_experts < global_num_experts: - # Create a map where some experts are local and some are not - expert_map = torch.full((global_num_experts, ), -1, device=device) - expert_map[ep_rank * num_local_experts:(ep_rank + 1) * - num_local_experts] = torch.arange(num_local_experts, - device=device) - num_experts = num_local_experts - - # --- 
Run Native Implementation (Golden Reference) --- - native_hidden_states_out = hidden_states.clone() - ( - native_permuted_hidden, - native_expert_tokens, - _, - _, - ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map, - num_experts, apply_a8_quantization) - # Simulate MLP output - native_mlp_output = torch.randn_like(native_permuted_hidden) - native_impl.unpermute(native_mlp_output, native_hidden_states_out) - - # --- Run AllGather Implementation --- - all_gather_hidden_states_out = hidden_states.clone() - ( - all_gather_permuted_hidden, - all_gather_expert_tokens, - _, - _, - ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights, - expert_map, num_experts, apply_a8_quantization) - - # Use the same simulated MLP output for a fair comparison - all_gather_mlp_output = native_mlp_output.clone() - - all_gather_impl.unpermute(all_gather_mlp_output, - all_gather_hidden_states_out) - - # --- Assertions --- - # Define tolerance based on dtype - atol = 1e-3 if dtype == torch.float16 else 1e-2 - rtol = 1e-3 if dtype == torch.float16 else 1e-2 - - # 1. Compare expert_tokens from pre_process - assert torch.allclose(native_expert_tokens.to( - all_gather_expert_tokens.device), - all_gather_expert_tokens, - atol=atol, - rtol=rtol), "Expert tokens do not match." - - # 2. Compare permuted_hidden_states from pre_process - num_valid_tokens = native_expert_tokens.sum() - assert torch.allclose(native_permuted_hidden[:num_valid_tokens].to( - all_gather_permuted_hidden.device), - all_gather_permuted_hidden[:num_valid_tokens], - atol=atol, - rtol=rtol), "Permuted hidden states do not match." - - # 3. Compare final hidden_states from post_process - assert torch.allclose(native_hidden_states_out.to( - all_gather_hidden_states_out.device), - all_gather_hidden_states_out, - atol=atol, - rtol=rtol), "Final hidden states do not match." 
- gc.collect() - torch.npu.empty_cache() - torch.npu.reset_peak_memory_stats() diff --git a/tests/e2e/singlecard/ops/test_rotary_embedding.py b/tests/e2e/singlecard/ops/test_rotary_embedding.py index 6f513b2..27e9b3b 100644 --- a/tests/e2e/singlecard/ops/test_rotary_embedding.py +++ b/tests/e2e/singlecard/ops/test_rotary_embedding.py @@ -182,7 +182,7 @@ def test_rotary_embedding_quant_with_leading_dim( ) ref_query, ref_key = rope.forward_native(positions, query, key) - query, key = torch.ops._C.rotary_embedding( + query, key = torch.ops._C_ascend.rotary_embedding( positions, query, key, @@ -239,7 +239,7 @@ class ModelwithRotaryEmbedding(nn.Module): # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph qkv = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(3, dim=-1) - query, key = torch.ops._C.rotary_embedding( + query, key = torch.ops._C_ascend.rotary_embedding( positions, q, k, @@ -299,7 +299,7 @@ def test_capture_rotary_embedding_in_aclgraph( # Validate if the rotary_embedding custom kernel is indeed inside the graph by # string match graph = str(gm.graph) - assert "_C.rotary_embedding" in graph + assert "_C_ascend.rotary_embedding" in graph return gm static_positions = torch.randint(0, max_position_embeddings, diff --git a/tests/e2e/singlecard/ops/test_vocabparallelembedding.py b/tests/e2e/singlecard/ops/test_vocabparallelembedding.py index 54d1127..64b974d 100644 --- a/tests/e2e/singlecard/ops/test_vocabparallelembedding.py +++ b/tests/e2e/singlecard/ops/test_vocabparallelembedding.py @@ -72,7 +72,7 @@ def test_get_masked_input_and_mask( # Get custom op result print("input_tensor:", input_tensor) - custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask( + custom_masked_input, custom_mask = torch.ops._C_ascend.get_masked_input_and_mask( input_tensor, test_case["org_start"], test_case["org_end"], test_case["padding"], test_case["added_start"], test_case["added_end"]) diff --git 
a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index 0c01a07..89d636a 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -1,14 +1,10 @@ from __future__ import annotations -import os - import pytest from vllm import SamplingParams from tests.e2e.conftest import VllmRunner -os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - @pytest.fixture def sampling_config(): @@ -20,9 +16,10 @@ def model_name(): return "wemaster/deepseek_mtp_main_random_bf16" -def test_mtp_correctness( +def mtp_correctness( sampling_config: SamplingParams, model_name: str, + num_speculative_tokens: int, ): example_prompts = [ "Hello, my name is", @@ -38,7 +35,7 @@ def test_mtp_correctness( tensor_parallel_size=1, gpu_memory_utilization=0.7, max_model_len=256, - enforce_eager=True) as ref_llm: + enforce_eager=False) as ref_llm: ref_outputs = ref_llm.generate(example_prompts, sampling_config) with VllmRunner( @@ -50,9 +47,9 @@ def test_mtp_correctness( enable_expert_parallel=True, speculative_config={ "method": "deepseek_mtp", - "num_speculative_tokens": 1, + "num_speculative_tokens": num_speculative_tokens, }, - enforce_eager=True, + enforce_eager=False, max_model_len=2000, additional_config={"ascend_scheduler_config": { "enabled": False @@ -74,3 +71,18 @@ def test_mtp_correctness( # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. 
assert matches > int(0.66 * len(ref_outputs)) + del spec_llm + + +def test_mtp1_correctness( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, model_name, 1) + + +def test_mtp2_correctness( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, model_name, 2) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py index 1bf6fea..1083557 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py @@ -1,14 +1,10 @@ from __future__ import annotations -import os - import pytest from vllm import SamplingParams from tests.e2e.conftest import VllmRunner -os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - @pytest.fixture def sampling_config(): diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py index 9a1bfb8..0c1546d 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py @@ -99,7 +99,6 @@ def test_ngram_correctness( assert matches > int(0.7 * len(ref_outputs)) -@pytest.mark.skipif(True, reason="oom in CI, fix me") @pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"]) def test_eagle_correctness( test_prompts: list[list[dict[str, Any]]], @@ -111,8 +110,6 @@ def test_eagle_correctness( Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. 
''' - if not use_eagle3: - pytest.skip("Not current support for the test.") ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True) ref_outputs = ref_llm.chat(test_prompts, sampling_config) @@ -121,7 +118,6 @@ def test_eagle_correctness( spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name() with VllmRunner( model_name, - trust_remote_code=True, enable_chunked_prefill=True, max_num_seqs=1, max_num_batched_tokens=2048, diff --git a/tests/e2e/singlecard/test_ascend_scheduler.py b/tests/e2e/singlecard/test_ascend_scheduler.py index 1a47ab6..b6ab3f3 100644 --- a/tests/e2e/singlecard/test_ascend_scheduler.py +++ b/tests/e2e/singlecard/test_ascend_scheduler.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +from vllm import SamplingParams from tests.e2e.conftest import VllmRunner from tests.e2e.model_utils import check_outputs_equal @@ -86,3 +87,25 @@ def test_chunked_prefill_with_ascend_scheduler( name_0="vllm_output", name_1="chunked_prefill_output", ) + + +def test_async_scheduling() -> None: + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 10 + sampling_params = SamplingParams(temperature=0.2, + max_tokens=10, + stop_token_ids=None) + + with VllmRunner( + "Qwen/Qwen2.5-0.5B-Instruct", + max_model_len=4096, + max_num_seqs=50, + dtype="bfloat16", + gpu_memory_utilization=0.9, + async_scheduling=True, + ) as vllm_model: + vllm_model.generate(prompts, sampling_params=sampling_params) diff --git a/tests/e2e/singlecard/test_guided_decoding.py b/tests/e2e/singlecard/test_guided_decoding.py index 6cb1c7b..ac2426e 100644 --- a/tests/e2e/singlecard/test_guided_decoding.py +++ b/tests/e2e/singlecard/test_guided_decoding.py @@ -17,17 +17,23 @@ # limitations under the License. 
# import json -import os +from typing import Any, Dict import jsonschema import pytest import regex as re + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.10.2"): + from vllm.sampling_params import GuidedDecodingParams, SamplingParams +else: + from vllm.sampling_params import SamplingParams, StructuredOutputsParams + from vllm.outputs import RequestOutput -from vllm.sampling_params import GuidedDecodingParams, SamplingParams from tests.e2e.conftest import VllmRunner -os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" MODEL_NAME = "Qwen/Qwen3-0.6B" GuidedDecodingBackend = ["xgrammar", "guidance", "outlines"] @@ -84,16 +90,29 @@ def sample_json_schema(): @pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend) def test_guided_json_completion(guided_decoding_backend: str, sample_json_schema): - sampling_params = SamplingParams( - temperature=1.0, - max_tokens=500, - guided_decoding=GuidedDecodingParams(json=sample_json_schema)) - - with VllmRunner( - MODEL_NAME, - seed=0, - guided_decoding_backend=guided_decoding_backend, - ) as vllm_model: + runner_kwargs: Dict[str, Any] = {} + if vllm_version_is("0.10.2"): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=500, + guided_decoding=GuidedDecodingParams(json=sample_json_schema)) + runner_kwargs = { + "seed": 0, + "guided_decoding_backend": guided_decoding_backend, + } + else: + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=500, + structured_outputs=StructuredOutputsParams( + json=sample_json_schema)) + runner_kwargs = { + "seed": 0, + "structured_outputs_config": { + "backend": guided_decoding_backend + }, + } + with VllmRunner(MODEL_NAME, **runner_kwargs) as vllm_model: prompts = [ f"Give an example JSON for an employee profile " f"that fits this schema: {sample_json_schema}" @@ -121,17 +140,29 @@ def test_guided_json_completion(guided_decoding_backend: str, def test_guided_regex(guided_decoding_backend: str, sample_regex): if 
guided_decoding_backend == "outlines": pytest.skip("Outlines doesn't support regex-based guided decoding.") + runner_kwargs: Dict[str, Any] = {} + if vllm_version_is("0.10.2"): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams(regex=sample_regex)) + runner_kwargs = { + "seed": 0, + "guided_decoding_backend": guided_decoding_backend, + } + else: + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + structured_outputs=StructuredOutputsParams(regex=sample_regex)) + runner_kwargs = { + "seed": 0, + "structured_outputs_config": { + "backend": guided_decoding_backend + }, + } - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - guided_decoding=GuidedDecodingParams(regex=sample_regex)) - - with VllmRunner( - MODEL_NAME, - seed=0, - guided_decoding_backend=guided_decoding_backend, - ) as vllm_model: + with VllmRunner(MODEL_NAME, **runner_kwargs) as vllm_model: prompts = [ f"Give an example IPv4 address with this regex: {sample_regex}" ] * 2 diff --git a/tests/e2e/singlecard/test_multistream_overlap_shared_expert.py b/tests/e2e/singlecard/test_multistream_overlap_shared_expert.py new file mode 100644 index 0000000..0f150c8 --- /dev/null +++ b/tests/e2e/singlecard/test_multistream_overlap_shared_expert.py @@ -0,0 +1,103 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +Compare the outputs of vLLM with multistream_overlap_shared_expert +enabled and disabled. + +Run `pytest tests/e2e/singlecard/test_multistream_overlap_shared_expert.py`. +""" + +import pytest +from vllm import SamplingParams + +from tests.e2e.conftest import VllmRunner +from tests.e2e.model_utils import check_outputs_equal + +MODELS = [ + "Qwen/Qwen3-0.6B", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [32]) +def test_models_with_multistream_overlap_shared_expert( + model: str, + max_tokens: int, +) -> None: + prompts = [ + "Hello, my name is", "The president of the United States is", + "The capital of France is", "The future of AI is" + ] + + sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0) + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=True, + additional_config={ + "multistream_overlap_shared_expert": True, + }, + ) as runner: + vllm_moe_ms_eager_outputs = runner.model.generate( + prompts, sampling_params) + + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=False, + additional_config={ + "multistream_overlap_shared_expert": True, + }, + ) as runner: + vllm_moe_ms_aclgraph_outputs = runner.model.generate( + prompts, sampling_params) + + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=True, + ) as runner: + vllm_eager_outputs = runner.model.generate(prompts, sampling_params) + + vllm_moe_ms_eager_outputs_list = [] + for output in vllm_moe_ms_eager_outputs: + vllm_moe_ms_eager_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + vllm_moe_ms_aclgraph_outputs_list = [] + for output in vllm_moe_ms_aclgraph_outputs: + vllm_moe_ms_aclgraph_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + vllm_eager_outputs_list = [] + for output in vllm_eager_outputs: + vllm_eager_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + check_outputs_equal( + 
outputs_0_lst=vllm_eager_outputs_list, + outputs_1_lst=vllm_moe_ms_eager_outputs_list, + name_0="vllm_eager_outputs", + name_1="vllm_moe_ms_eager_outputs", + ) + + check_outputs_equal( + outputs_0_lst=vllm_eager_outputs_list, + outputs_1_lst=vllm_moe_ms_aclgraph_outputs_list, + name_0="vllm_eager_outputs", + name_1="vllm_moe_ms_aclgraph_outputs", + ) diff --git a/tests/e2e/singlecard/test_vlm.py b/tests/e2e/singlecard/test_vlm.py index 5fe27f6..59fb10e 100644 --- a/tests/e2e/singlecard/test_vlm.py +++ b/tests/e2e/singlecard/test_vlm.py @@ -20,19 +20,14 @@ Run `pytest tests/test_offline_inference.py`. """ -import os -import pytest from vllm import SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from tests.e2e.conftest import VllmRunner -os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" - -@pytest.mark.skip(reason="fix me") def test_multimodal_vl(prompt_template): image = ImageAsset("cherry_blossom") \ .pil_image.convert("RGB") @@ -52,9 +47,12 @@ def test_multimodal_vl(prompt_template): "fps": 1, }, enforce_eager=True) as vllm_model: - vllm_model.generate_greedy(prompts=prompts, - images=images, - max_tokens=64) + outputs = vllm_model.generate_greedy(prompts=prompts, + images=images, + max_tokens=64) + assert len(outputs) == len(prompts) + for _, output_str in outputs: + assert output_str, "Generated output should not be empty." def test_multimodal_audio(): @@ -86,4 +84,7 @@ def test_multimodal_audio(): dtype="bfloat16", limit_mm_per_prompt={"audio": 2}, gpu_memory_utilization=0.9) as runner: - runner.generate(inputs, sampling_params=sampling_params) + outputs = runner.generate(inputs, sampling_params=sampling_params) + + assert outputs is not None, "Generated outputs should not be None." + assert len(outputs) > 0, "Generated outputs should not be empty." 
diff --git a/tests/e2e/vllm_interface/singlecard/test_sampler.py b/tests/e2e/vllm_interface/singlecard/test_sampler.py new file mode 100644 index 0000000..662e76e --- /dev/null +++ b/tests/e2e/vllm_interface/singlecard/test_sampler.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from vllm import SamplingParams + +from tests.e2e.conftest import VllmRunner + + +def test_models_topk() -> None: + example_prompts = [ + "The capital of France is", + ] + sampling_params = SamplingParams(max_tokens=10, + temperature=0.0, + top_k=10, + top_p=0.9) + + with VllmRunner("Qwen/Qwen3-0.6B", + max_model_len=4096, + gpu_memory_utilization=0.7) as runner: + runner.generate(example_prompts, sampling_params) diff --git a/tests/e2e/vllm_interface/vllm_test.cfg b/tests/e2e/vllm_interface/vllm_test.cfg new file mode 100644 index 0000000..4d077b0 --- /dev/null +++ b/tests/e2e/vllm_interface/vllm_test.cfg @@ -0,0 +1,2 @@ +# Base docker image used to build the vllm-ascend e2e test image, which is built in the vLLM repository +BASE_IMAGE_NAME="quay.io/ascend/cann:8.2.rc1-910b-ubuntu22.04-py3.11" diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index 556c8d7..d553637 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -7,8 +7,7 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, AscendAttentionBackendImpl, AscendAttentionMetadataBuilder, AscendAttentionState, - AscendMetadata, - CommonAttentionState) + AscendMetadata) from vllm_ascend.attention.utils import AscendCommonAttentionMetadata @@ -25,10 +24,6 @@ class TestAscendAttentionBackend(TestBase): self.assertEqual(AscendAttentionBackend.get_metadata_cls(), AscendMetadata) - def test_get_state_cls(self): - self.assertEqual(AscendAttentionBackend.get_state_cls(), - CommonAttentionState) - def test_get_builder_cls(self): self.assertEqual(AscendAttentionBackend.get_builder_cls(), AscendAttentionMetadataBuilder) @@ -72,7 +67,8 @@ class TestAscendAttentionMetadataBuilder(TestBase): self.mock_vllm_config.model_config.max_model_len = 640 self.mock_vllm_config.cache_config.block_size = 64 self.mock_device = 'cpu:0' - self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config, + self.builder = 
AscendAttentionMetadataBuilder(None, None, + self.mock_vllm_config, self.mock_device) def test_reorder_batch(self): @@ -100,19 +96,21 @@ class TestAscendAttentionMetadataBuilder(TestBase): max_query_len=5, decode_token_per_req=torch.tensor([1, 1]), block_table_tensor=torch.zeros((10, 10)), - slot_mapping_cpu=torch.tensor(range(20)), + slot_mapping=torch.tensor(range(20)), actual_seq_lengths_q=torch.tensor([0, 1]), positions=torch.tensor([10, 10]), attn_mask=torch.ones((10, 10)), spec_attn_mask=None, - attn_state=AscendAttentionState.PrefillNoCache) + attn_state=AscendAttentionState.PrefillNoCache, + num_computed_tokens_cpu=None, + seq_lens=None) mock_nz_tensor = MagicMock() mock_model = MagicMock() mock_nd_to_nz_2d.return_value = mock_nz_tensor mock_npu_format_cast.return_value = mock_nz_tensor - self.builder.build(common_attn_metadata, mock_model) + self.builder.build(1, common_attn_metadata, mock_model) @patch('vllm_ascend.attention.attention_v1.AscendMetadata') @patch('torch_npu.npu_format_cast') @@ -131,12 +129,14 @@ class TestAscendAttentionMetadataBuilder(TestBase): max_query_len=6, decode_token_per_req=torch.tensor([1, 1, 1]), block_table_tensor=torch.zeros((10, 10)), - slot_mapping_cpu=torch.tensor(range(20)), + slot_mapping=torch.tensor(range(20)), actual_seq_lengths_q=torch.tensor([0, 1, 2]), positions=torch.tensor([10, 10]), attn_mask=torch.ones((15, 15)), spec_attn_mask=None, - attn_state=AscendAttentionState.ChunkedPrefill) + attn_state=AscendAttentionState.ChunkedPrefill, + num_computed_tokens_cpu=None, + seq_lens=None) mock_ascend_attention_state = MagicMock() mock_ascend_attention_state.PrefillNoCache = 0 @@ -146,7 +146,7 @@ class TestAscendAttentionMetadataBuilder(TestBase): mock_nd_to_nz_spec.return_value = mock_nz_tensor mock_npu_format_cast.return_value = mock_nz_tensor - self.builder.build(common_attn_metadata, mock_model) + self.builder.build(1, common_attn_metadata, mock_model) @patch('vllm_ascend.attention.attention_v1.AscendMetadata') 
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) @@ -160,15 +160,17 @@ class TestAscendAttentionMetadataBuilder(TestBase): max_query_len=6, decode_token_per_req=torch.tensor([1, 1, 1]), block_table_tensor=torch.zeros((10, 10)), - slot_mapping_cpu=torch.tensor(range(20)), + slot_mapping=torch.tensor(range(20)), actual_seq_lengths_q=torch.tensor([0, 1, 2]), positions=torch.tensor([10, 10]), attn_mask=torch.ones((15, 15)), spec_attn_mask=None, - attn_state=AscendAttentionState.ChunkedPrefill) + attn_state=AscendAttentionState.ChunkedPrefill, + num_computed_tokens_cpu=None, + seq_lens=None) mock_model = MagicMock() - self.builder.build(common_attn_metadata, mock_model) + self.builder.build(1, common_attn_metadata, mock_model) class TestAscendAttentionBackendImpl(TestBase): @@ -341,36 +343,6 @@ class TestAscendAttentionBackendImpl(TestBase): mock_flash_attention.assert_called_once() assert output.shape == (10, 8 * 64) - @patch('torch_npu._npu_reshape_and_cache') - @patch('torch_npu._npu_flash_attention') - def test_forward_prefill_no_cache_swa(self, mock_flash_attention, - mock_reshape_cache): - """Test forward pass in PrefillNoCache state""" - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) - kv_cache = torch.empty(2, 5, 128, 8, 64) - metadata = self.attn_metadata - metadata.attn_state = AscendAttentionState.PrefillNoCache - metadata.attn_mask = torch.randn(1, 1, 10, 10) - metadata.seq_lens = torch.tensor([10]) - metadata.num_actual_tokens = 10 - metadata.slot_mapping = torch.zeros(10, dtype=torch.long) - layer = self.layer_no_quant - # layer.quant_method.apply.return_value = metadata - print(self.layer_no_quant._v_scale_float) - output = self.impl_swa.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) - - mock_reshape_cache.assert_called_once() - mock_flash_attention.assert_called_once() - assert output.shape == (10, 8 * 64) - 
@patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_flash_attention_qlens') def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens, @@ -401,10 +373,12 @@ class TestAscendAttentionBackendImpl(TestBase): mock_flash_attention_qlens.assert_called_once() assert output.shape == (10, 8 * 64) + @patch('vllm_ascend.attention.attention_v1.get_forward_context') @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_paged_attention') def test_forward_decode_only(self, mock_paged_attention, - mock_npu_reshape_and_cache): + mock_npu_reshape_and_cache, + mock_get_forward_context): """Test forward pass in DecodeOnly state""" query = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64) @@ -418,6 +392,8 @@ class TestAscendAttentionBackendImpl(TestBase): metadata.slot_mapping = torch.zeros(10, dtype=torch.long) layer = self.layer_no_quant + mock_get_forward_context.return_value = MagicMock(capturing=False) + output = self.impl.forward(layer, query, key, @@ -458,6 +434,44 @@ class TestAscendAttentionBackendImpl(TestBase): mock_fused_infer_attention_score.assert_called_once() assert output.shape == (10, 8 * 64) + @patch('vllm_ascend.attention.attention_v1.get_forward_context') + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_paged_attention') + @patch('torch_npu.npu_fused_infer_attention_score') + def test_forward_decode_only_swa_seq_len_mismatch( + self, mock_fused_infer_attention_score, mock_paged_attention, + mock_npu_reshape_and_cache, mock_get_forward_context): + """Test forward pass in DecodeOnly state when seq)len_mismatch""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + + metadata = self.attn_metadata + metadata.attn_state = AscendAttentionState.DecodeOnly + metadata.seq_lens = torch.tensor([10]) # len == 1 != query.size(0)==10 + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 
+ metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + + mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, + 64), 1) + + mock_get_forward_context.return_value = MagicMock(capturing=False) + + output = self.impl_swa.forward(self.layer_no_quant, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_paged_attention.assert_called_once() + mock_fused_infer_attention_score.assert_not_called() + + assert output.shape == (10, 8 * 64) + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) @patch('torch_npu._npu_reshape_and_cache') @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill') diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 6360504..0164057 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -186,10 +186,39 @@ class TestAscendMLAMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + ascend_config = MagicMock() with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): - builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) + builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, + mock_device) + + self.assertEqual(builder.block_size, + mock_vllm_config.cache_config.block_size) + self.assertEqual( + builder.chunked_prefill_enabled, + mock_vllm_config.scheduler_config.chunked_prefill_enabled) + + def test_ascend_mla_metadata_builder_spec_decode(self): + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.model_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + 
mock_spec_config = MagicMock() + mock_spec_config.num_speculative_tokens = 3 + mock_vllm_config.speculative_config = mock_spec_config + + ascend_config = MagicMock() + with patch("vllm_ascend.attention.mla_v1.get_ascend_config", + return_value=ascend_config): + builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, + mock_device) self.assertEqual(builder.block_size, mock_vllm_config.cache_config.block_size) @@ -207,9 +236,12 @@ class TestAscendMLAMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): - builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) + builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, + mock_device) builder.decode_threshold = 1 input_batch = MagicMock() @@ -522,7 +554,11 @@ class TestAscendMLAImpl(TestBase): self.impl.num_kv_heads = self.impl.num_heads decode_res, prefill_res = self.impl._mla_preprocess( - hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False) + "mock_layer", + hidden_states, + kv_cache, + attn_metadata, + need_gather_q_kv=False) self.assertIsNotNone(decode_res) self.assertIsNotNone(prefill_res) diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py new file mode 100644 index 0000000..347fbd1 --- /dev/null +++ b/tests/ut/compilation/test_acl_graph.py @@ -0,0 +1,720 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from unittest.mock import MagicMock, Mock, patch + +import torch +from vllm.compilation.cuda_graph import CUDAGraphOptions +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.forward_context import BatchDescriptor, ForwardContext + +from tests.ut.base import TestBase +from vllm_ascend.compilation.acl_graph import ACLGraphEntry, ACLGraphWrapper + + +class TestACLGraphEntry(TestBase): + + def test_aclgraph_entry_initialization(self): + """Test ACLGraphEntry initialization with default values""" + batch_descriptor = BatchDescriptor( + num_tokens=30, + uniform_decode=False, + ) + + entry = ACLGraphEntry(batch_descriptor=batch_descriptor) + + self.assertEqual(entry.batch_descriptor, batch_descriptor) + self.assertIsNone(entry.aclgraph) + self.assertIsNone(entry.output) + self.assertIsNone(entry.input_addresses) + + def test_aclgraph_entry_with_values(self): + """Test ACLGraphEntry initialization with specified values""" + batch_descriptor = BatchDescriptor( + num_tokens=30, + uniform_decode=False, + ) + + mock_graph = MagicMock() + mock_output = MagicMock() + input_addresses = [12345, 67890] + + entry = ACLGraphEntry(batch_descriptor=batch_descriptor, + aclgraph=mock_graph, + output=mock_output, + input_addresses=input_addresses) + + self.assertEqual(entry.batch_descriptor, batch_descriptor) + self.assertEqual(entry.aclgraph, mock_graph) + self.assertEqual(entry.output, mock_output) + self.assertEqual(entry.input_addresses, input_addresses) + + +class TestACLGraphWrapper(TestBase): + + def setUp(self): + """Set up test fixtures""" + super().setUp() + + # Mock VllmConfig + self.mock_vllm_config = MagicMock(spec=VllmConfig) + self.mock_vllm_config.compilation_config = MagicMock() + + # Mock runnable function + self.mock_runnable = MagicMock(return_value="test_output") + + # Mock graph pool + 
self.mock_graph_pool = MagicMock() + + # Mock CUDAGraphOptions + self.mock_cudagraph_options = MagicMock(spec=CUDAGraphOptions) + self.mock_cudagraph_options.debug_log_enable = False + self.mock_cudagraph_options.gc_disable = False + self.mock_cudagraph_options.weak_ref_output = False + + # Mock BatchDescriptor + self.mock_batch_descriptor = BatchDescriptor( + num_tokens=30, + uniform_decode=False, + ) + + # Mock ForwardContext + self.mock_forward_context = MagicMock(spec=ForwardContext) + self.mock_forward_context.batch_descriptor = self.mock_batch_descriptor + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL + + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + def test_initialization_with_default_options(self, mock_envs, + mock_current_platform): + """Test ACLGraphWrapper initialization with default CUDAGraphOptions""" + mock_envs.VLLM_LOGGING_LEVEL = "INFO" + mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + + wrapper = ACLGraphWrapper(runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool) + + self.assertEqual(wrapper.runnable, self.mock_runnable) + self.assertEqual(wrapper.vllm_config, self.mock_vllm_config) + self.assertEqual(wrapper.graph_pool, self.mock_graph_pool) + self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL) + self.assertFalse(wrapper.is_debugging_mode) + self.assertIsInstance(wrapper.aclgraph_options, CUDAGraphOptions) + self.assertEqual(wrapper.concrete_aclgraph_entries, {}) + + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + def test_initialization_with_custom_options(self, mock_envs, + mock_current_platform): + """Test ACLGraphWrapper initialization with custom CUDAGraphOptions""" + mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" + 
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + self.assertEqual(wrapper.runnable, self.mock_runnable) + self.assertEqual(wrapper.vllm_config, self.mock_vllm_config) + self.assertEqual(wrapper.graph_pool, self.mock_graph_pool) + self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL) + self.assertTrue(wrapper.is_debugging_mode) + self.assertEqual(wrapper.aclgraph_options, self.mock_cudagraph_options) + self.assertEqual(wrapper.concrete_aclgraph_entries, {}) + + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + def test_initialization_assertion_error(self, mock_envs, + mock_current_platform): + """Test ACLGraphWrapper initialization raises AssertionError for NONE mode""" + mock_envs.VLLM_LOGGING_LEVEL = "INFO" + mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + + with self.assertRaises(AssertionError): + ACLGraphWrapper(runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.NONE, + graph_pool=self.mock_graph_pool) + + @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + def test_call_with_none_runtime_mode(self, mock_envs, + mock_current_platform, + mock_get_forward_context): + """Test __call__ method when runtime mode is NONE""" + mock_envs.VLLM_LOGGING_LEVEL = "INFO" + mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + mock_get_forward_context.return_value = self.mock_forward_context + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.NONE + + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, + 
vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + result = wrapper("arg1", "arg2") + + # Should call the runnable directly without graph capture + self.mock_runnable.assert_called_once_with("arg1", "arg2") + self.assertEqual(result, "test_output") + + @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + def test_call_with_mismatched_runtime_mode(self, mock_envs, + mock_current_platform, + mock_get_forward_context): + """Test __call__ method when runtime mode doesn't match wrapper mode""" + mock_envs.VLLM_LOGGING_LEVEL = "INFO" + mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + mock_get_forward_context.return_value = self.mock_forward_context + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE # Different from FULL + + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + result = wrapper("arg1", "arg2") + + # Should call the runnable directly without graph capture + self.mock_runnable.assert_called_once_with("arg1", "arg2") + self.assertEqual(result, "test_output") + + @patch('vllm_ascend.compilation.acl_graph.torch') + @patch( + 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' + ) + @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + @patch('vllm_ascend.compilation.acl_graph.compilation_counter') + @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') + def test_call_capture_graph_first_time( + self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs, 
+ mock_current_platform, mock_get_forward_context, + mock_validate_cudagraph_capturing_enabled, mock_torch): + """Test __call__ method captures graph for the first time""" + mock_envs.VLLM_LOGGING_LEVEL = "INFO" + mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + mock_get_forward_context.return_value = self.mock_forward_context + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL + + # Mock torch.npu.NPUGraph + mock_npu_graph = MagicMock() + mock_torch.npu.NPUGraph.return_value = mock_npu_graph + + # Mock torch.npu.graph context manager + mock_graph_context = MagicMock() + mock_torch.npu.graph.return_value = mock_graph_context + mock_graph_context.__enter__ = Mock(return_value=None) + mock_graph_context.__exit__ = Mock(return_value=None) + + # Mock weak_ref_tensors to return the same output + mock_weak_ref_tensors.return_value = "weak_ref_output" + + # Ensure torch.Tensor can be correctly identified by isinstance + mock_torch.Tensor = torch.Tensor + + # Set up the compilation counter mock + mock_compilation_counter.num_cudagraph_captured = 0 + + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + # Create a real torch tensor for the test, not a mock + test_tensor = torch.tensor([1, 2, 3]) + + # Call the wrapper + result = wrapper(test_tensor, "arg2") + + # Verify graph capture happened + mock_validate_cudagraph_capturing_enabled.assert_called_once() + mock_torch.npu.NPUGraph.assert_called_once() + mock_torch.npu.graph.assert_called_once_with(mock_npu_graph, + pool=self.mock_graph_pool) + self.mock_runnable.assert_called_once_with(test_tensor, "arg2") + + # Verify the entry was created and updated + self.assertIn(self.mock_batch_descriptor, + wrapper.concrete_aclgraph_entries) + entry = wrapper.concrete_aclgraph_entries[self.mock_batch_descriptor] + 
self.assertEqual(entry.aclgraph, mock_npu_graph) + self.assertEqual(entry.output, "weak_ref_output") + + # Verify compilation counter was incremented + self.assertEqual(mock_compilation_counter.num_cudagraph_captured, 1) + + # Should return the original output (not weak ref) + self.assertEqual(result, "test_output") + + @patch('vllm_ascend.compilation.acl_graph.torch') + @patch( + 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' + ) + @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + @patch('vllm_ascend.compilation.acl_graph.compilation_counter') + @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') + def test_call_replay_graph(self, mock_weak_ref_tensors, + mock_compilation_counter, mock_envs, + mock_current_platform, mock_get_forward_context, + mock_validate_cudagraph_capturing_enabled, + mock_torch): + """Test __call__ method replays graph when already captured""" + mock_envs.VLLM_LOGGING_LEVEL = "INFO" + mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + mock_get_forward_context.return_value = self.mock_forward_context + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL + + # Mock torch.npu.NPUGraph + mock_npu_graph = MagicMock() + mock_torch.npu.NPUGraph.return_value = mock_npu_graph + + # Mock torch.npu.graph context manager + mock_graph_context = MagicMock() + mock_torch.npu.graph.return_value = mock_graph_context + mock_graph_context.__enter__ = Mock(return_value=None) + mock_graph_context.__exit__ = Mock(return_value=None) + + # Mock weak_ref_tensors to return the same output + mock_weak_ref_tensors.return_value = "weak_ref_output" + + # Ensure torch.Tensor can be correctly identified by isinstance + mock_torch.Tensor = torch.Tensor + + # Set up the compilation counter mock + mock_compilation_counter.num_cudagraph_captured = 0 + + wrapper = 
ACLGraphWrapper( + runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + # Create a real torch tensor for the test, not a mock + test_tensor = torch.tensor([1, 2, 3]) + + # First call to capture the graph + first_result = wrapper(test_tensor, "arg2") + + # Verify graph capture happened during first call + mock_validate_cudagraph_capturing_enabled.assert_called_once() + mock_torch.npu.NPUGraph.assert_called_once() + mock_torch.npu.graph.assert_called_once() + + # Reset mock to track second call + self.mock_runnable.reset_mock() + mock_npu_graph.reset_mock() + + # Second call should replay the graph + second_result = wrapper(test_tensor, "arg2") + + # Verify runnable was called only during capture (not during replay) + self.mock_runnable.assert_not_called() + + # Verify graph replay happened + mock_npu_graph.replay.assert_called_once() + + # Both calls should return the weak ref output + self.assertEqual(first_result, "test_output") # Original output + self.assertEqual(second_result, "weak_ref_output") # Weak ref output + + @patch('vllm_ascend.compilation.acl_graph.torch') + @patch( + 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' + ) + @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') + def test_call_with_debug_mode_input_address_check( + self, mock_weak_ref_tensors, mock_envs, mock_current_platform, + mock_get_forward_context, + mock_validate_cudagraph_capturing_enabled, mock_torch): + """Test __call__ method with debug mode input address checking""" + mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode + mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + 
mock_get_forward_context.return_value = self.mock_forward_context + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL + + # Mock torch.npu.NPUGraph + mock_npu_graph = MagicMock() + mock_torch.npu.NPUGraph.return_value = mock_npu_graph + + # Mock torch.npu.graph context manager + mock_graph_context = MagicMock() + mock_torch.npu.graph.return_value = mock_graph_context + mock_graph_context.__enter__ = Mock(return_value=None) + mock_graph_context.__exit__ = Mock(return_value=None) + + # Mock weak_ref_tensors + mock_weak_ref_tensors.return_value = "weak_ref_output" + + # Ensure torch.Tensor can be correctly identified by isinstance + mock_torch.Tensor = torch.Tensor + + # Create a mock tensor as the output of runnable + mock_output_tensor = torch.tensor([4, 5, 6]) + self.mock_runnable.return_value = mock_output_tensor + + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + # First call to capture the graph + tensor = torch.tensor([1, 2, 3]) # Create tensor once + _ = wrapper(tensor, "arg2") + + # Second call with same tensor addresses should work + _ = wrapper(tensor, "arg2") # Use the same tensor object + + # Should not raise AssertionError + self.assertTrue(True) + + @patch('vllm_ascend.compilation.acl_graph.torch') + @patch( + 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' + ) + @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') + def test_call_with_debug_mode_input_address_mismatch( + self, mock_weak_ref_tensors, mock_envs, mock_current_platform, + mock_get_forward_context, + mock_validate_cudagraph_capturing_enabled, mock_torch): + """Test __call__ method with debug 
mode input address mismatch raises AssertionError""" + mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode + mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + mock_get_forward_context.return_value = self.mock_forward_context + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL + + # Mock torch.npu.NPUGraph + mock_npu_graph = MagicMock() + mock_torch.npu.NPUGraph.return_value = mock_npu_graph + + # Mock torch.npu.graph context manager + mock_graph_context = MagicMock() + mock_torch.npu.graph.return_value = mock_graph_context + mock_graph_context.__enter__ = Mock(return_value=None) + mock_graph_context.__exit__ = Mock(return_value=None) + + # Mock weak_ref_tensors + mock_weak_ref_tensors.return_value = "weak_ref_output" + + # Ensure torch.Tensor can be correctly identified by isinstance + mock_torch.Tensor = torch.Tensor + + # Create a mock tensor as the output of runnable + mock_output_tensor = torch.tensor([4, 5, 6]) + self.mock_runnable.return_value = mock_output_tensor + + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + # First call to capture the graph + tensor1 = torch.tensor([1, 2, 3]) + _ = wrapper(tensor1, "arg2") + + # Second call with different tensor addresses should raise AssertionError + tensor2 = torch.tensor([4, 5, + 6]) # Different values, different address + + with self.assertRaises(AssertionError) as context: + wrapper(tensor2, "arg2") + + self.assertIn("Input addresses for aclgraphs are different", + str(context.exception)) + + @patch('vllm_ascend.compilation.acl_graph.torch') + @patch( + 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' + ) + @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.compilation.acl_graph.current_platform') + 
@patch('vllm_ascend.compilation.acl_graph.envs') + @patch('vllm_ascend.compilation.acl_graph.compilation_counter') + @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') + @patch('vllm_ascend.compilation.acl_graph.patch') + def test_call_capture_graph_with_gc_disable( + self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter, + mock_envs, mock_current_platform, mock_get_forward_context, + mock_validate_cudagraph_capturing_enabled, mock_torch): + """Test __call__ method captures graph with gc_disable option enabled""" + mock_envs.VLLM_LOGGING_LEVEL = "INFO" + mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + mock_get_forward_context.return_value = self.mock_forward_context + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL + + # Enable gc_disable option + self.mock_cudagraph_options.gc_disable = True + # weak_ref_output is not enabled by default + + # Mock torch.npu.NPUGraph + mock_npu_graph = MagicMock() + mock_torch.npu.NPUGraph.return_value = mock_npu_graph + + # Mock torch.npu.graph context manager + mock_graph_context = MagicMock() + mock_torch.npu.graph.return_value = mock_graph_context + mock_graph_context.__enter__ = Mock(return_value=None) + mock_graph_context.__exit__ = Mock(return_value=None) + + # Mock patch context manager + mock_exit_stack = MagicMock() + mock_patch.return_value = mock_exit_stack + mock_exit_stack.enter_context = Mock() + + # Mock weak_ref_tensors to simulate the actual behavior: + # 1. First call (inside the graph context) should return "inner_output" + # 2. 
Second call (for entry.output) should return "weak_ref_output" + mock_weak_ref_tensors.side_effect = ["inner_output", "weak_ref_output"] + + # Ensure torch.Tensor can be correctly identified by isinstance + mock_torch.Tensor = torch.Tensor + + # Set up the compilation counter mock + mock_compilation_counter.num_cudagraph_captured = 0 + + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + # Create a real torch tensor for the test, not a mock + test_tensor = torch.tensor([1, 2, 3]) + + # Call the wrapper + result = wrapper(test_tensor, "arg2") + + # Verify patch was called to disable gc + self.assertTrue(mock_patch.called) + + # Verify graph capture happened + mock_validate_cudagraph_capturing_enabled.assert_called_once() + mock_torch.npu.NPUGraph.assert_called_once() + mock_torch.npu.graph.assert_called_once_with(mock_npu_graph, + pool=self.mock_graph_pool) + + # Should return the original output (not weak ref) since weak_ref_output is not enabled + self.assertEqual(result, "test_output") + + @patch('vllm_ascend.compilation.acl_graph.torch') + @patch( + 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' + ) + @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + @patch('vllm_ascend.compilation.acl_graph.compilation_counter') + @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') + def test_call_capture_graph_with_weak_ref_output( + self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs, + mock_current_platform, mock_get_forward_context, + mock_validate_cudagraph_capturing_enabled, mock_torch): + """Test __call__ method captures graph with weak_ref_output option enabled""" + mock_envs.VLLM_LOGGING_LEVEL = "INFO" + 
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + mock_get_forward_context.return_value = self.mock_forward_context + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL + + # Enable weak_ref_output option + self.mock_cudagraph_options.weak_ref_output = True + + # Mock torch.npu.NPUGraph + mock_npu_graph = MagicMock() + mock_torch.npu.NPUGraph.return_value = mock_npu_graph + + # Mock torch.npu.graph context manager + mock_graph_context = MagicMock() + mock_torch.npu.graph.return_value = mock_graph_context + mock_graph_context.__enter__ = Mock(return_value=None) + mock_graph_context.__exit__ = Mock(return_value=None) + + # Mock weak_ref_tensors to simulate the actual behavior: + # 1. First call (inside the graph context with weak_ref_output=True) should return "weak_ref_output" + # 2. Second call (for entry.output) should return "weak_ref_output" + mock_weak_ref_tensors.side_effect = [ + "weak_ref_output", "weak_ref_output" + ] + + # Ensure torch.Tensor can be correctly identified by isinstance + mock_torch.Tensor = torch.Tensor + + # Set up the compilation counter mock + mock_compilation_counter.num_cudagraph_captured = 0 + + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + # Create a real torch tensor for the test, not a mock + test_tensor = torch.tensor([1, 2, 3]) + + # Call the wrapper + result = wrapper(test_tensor, "arg2") + + # Verify weak_ref_tensors was called twice (once for inner output, once for final output) + self.assertEqual(mock_weak_ref_tensors.call_count, 2) + + # Verify graph capture happened + mock_validate_cudagraph_capturing_enabled.assert_called_once() + mock_torch.npu.NPUGraph.assert_called_once() + mock_torch.npu.graph.assert_called_once_with(mock_npu_graph, + pool=self.mock_graph_pool) + + # Should return the weak ref output 
when weak_ref_output option is enabled + self.assertEqual(result, "weak_ref_output") + + @patch('vllm_ascend.compilation.acl_graph.get_forward_context') + @patch('vllm_ascend.compilation.acl_graph.current_platform') + @patch('vllm_ascend.compilation.acl_graph.envs') + @patch('vllm_ascend.compilation.acl_graph.logger') + def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs, + mock_current_platform, + mock_get_forward_context): + """Test __call__ method captures graph with debug logging enabled""" + mock_envs.VLLM_LOGGING_LEVEL = "INFO" + mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool + mock_get_forward_context.return_value = self.mock_forward_context + self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL + + # Enable debug logging + self.mock_cudagraph_options.debug_log_enable = True + # weak_ref_output is not enabled by default + + # Mock torch + with patch('vllm_ascend.compilation.acl_graph.torch') as mock_torch: + # Mock torch.npu.NPUGraph + mock_npu_graph = MagicMock() + mock_torch.npu.NPUGraph.return_value = mock_npu_graph + + # Mock torch.npu.graph context manager + mock_graph_context = MagicMock() + mock_torch.npu.graph.return_value = mock_graph_context + mock_graph_context.__enter__ = Mock(return_value=None) + mock_graph_context.__exit__ = Mock(return_value=None) + + # Ensure torch.Tensor can be correctly identified by isinstance + mock_torch.Tensor = torch.Tensor + + # Mock weak_ref_tensors + with patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors' + ) as mock_weak_ref_tensors: + # Mock weak_ref_tensors to simulate the actual behavior: + # 1. First call (inside the graph context) should return "inner_output" + # 2. 
Second call (for entry.output) should return "weak_ref_output" + mock_weak_ref_tensors.side_effect = [ + "inner_output", "weak_ref_output" + ] + + # Mock validate_cudagraph_capturing_enabled + with patch( + 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' + ): + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + # Create a real torch tensor for the test, not a mock + test_tensor = torch.tensor([1, 2, 3]) + + # Call the wrapper + _ = wrapper(test_tensor, "arg2") + + # Verify debug log was called + mock_logger.debug.assert_called_once() + + def test_getattr_access_runnable_attributes(self): + """Test __getattr__ method accesses runnable attributes""" + mock_runnable = MagicMock() + mock_runnable.test_attr = "test_value" + + wrapper = ACLGraphWrapper( + runnable=mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + # Should be able to access attributes of the runnable + self.assertEqual(wrapper.test_attr, "test_value") + + def test_getattr_attribute_not_exists(self): + """Test __getattr__ method raises AttributeError for non-existent attributes""" + + # Create a simple object without any attributes + class EmptyRunnable: + pass + + mock_runnable = EmptyRunnable() + + wrapper = ACLGraphWrapper( + runnable=mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + # Should raise AttributeError for non-existent attributes + with self.assertRaises(AttributeError) as context: + _ = wrapper.non_existent_attr + + self.assertIn("Attribute non_existent_attr not exists", + str(context.exception)) + + def test_unwrap_method(self): + """Test unwrap method 
returns the original runnable""" + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, + vllm_config=self.mock_vllm_config, + runtime_mode=CUDAGraphMode.FULL, + graph_pool=self.mock_graph_pool, + cudagraph_options=self.mock_cudagraph_options) + + unwrapped = wrapper.unwrap() + self.assertEqual(unwrapped, self.mock_runnable) diff --git a/tests/ut/core/test_schedule_config.py b/tests/ut/core/test_schedule_config.py index df36b52..84fd643 100644 --- a/tests/ut/core/test_schedule_config.py +++ b/tests/ut/core/test_schedule_config.py @@ -27,7 +27,6 @@ class TestAscendSchedulerConfig(TestBase): max_model_len=8192, is_multimodal_model=False, send_delta_data=False, - scheduler_delay_factor=0, ) def test_initialize_from_config_with_default(self): @@ -36,7 +35,6 @@ class TestAscendSchedulerConfig(TestBase): self.basic_scheduler_config, {}) self.assertEqual(ascend_config.enable_chunked_prefill, False) self.assertEqual(ascend_config.policy, "fcfs") - self.assertEqual(ascend_config.num_scheduler_steps, 1) self.assertEqual(ascend_config.scheduler_cls, "vllm_ascend.core.scheduler.AscendScheduler") self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192) @@ -49,19 +47,21 @@ class TestAscendSchedulerConfig(TestBase): AscendSchedulerConfig( enable_chunked_prefill=False, policy="fcfs", - num_scheduler_steps=1, scheduler_cls="vllm_ascend.core.scheduler.AscendScheduler", max_num_batched_tokens=2048, max_model_len=2048, + max_long_partial_prefills=1, + long_prefill_token_threshold=512, ), ) self.assertEqual(ascend_config.enable_chunked_prefill, False) self.assertEqual(ascend_config.policy, "fcfs") - self.assertEqual(ascend_config.num_scheduler_steps, 1) self.assertEqual(ascend_config.scheduler_cls, "vllm_ascend.core.scheduler.AscendScheduler") self.assertEqual(ascend_config.max_num_batched_tokens, 2048) self.assertEqual(ascend_config.encoder_cache_size, 2048) + self.assertEqual(ascend_config.max_long_partial_prefills, 1) + 
self.assertEqual(ascend_config.long_prefill_token_threshold, 512) def test_not_implemented_policy(self): with self.assertRaises(NotImplementedError) as context: @@ -78,28 +78,6 @@ class TestAscendSchedulerConfig(TestBase): str(context.exception), ) - def test_not_implemented_multimodal(self): - with self.assertRaises(NotImplementedError) as context: - AscendSchedulerConfig.initialize_from_config( - SchedulerConfig(is_multimodal_model=True), {}) - self.assertIn("currently AscendScheduler only supports LLM models", - str(context.exception)) - - def test_not_implemented_multi_step(self): - with self.assertRaises(NotImplementedError) as context: - AscendSchedulerConfig.initialize_from_config( - self.basic_scheduler_config, - AscendSchedulerConfig( - num_scheduler_steps=2, - max_num_batched_tokens=2048, - max_model_len=2048, - ), - ) - self.assertIn( - "currently AscendScheduler doesn't support multi-step", - str(context.exception), - ) - def test_not_implemented_send_delta_data(self): with self.assertRaises(NotImplementedError) as context: AscendSchedulerConfig.initialize_from_config( @@ -115,27 +93,17 @@ class TestAscendSchedulerConfig(TestBase): str(context.exception), ) - def test_not_implemented_delay_factor(self): - with self.assertRaises(NotImplementedError) as context: - AscendSchedulerConfig.initialize_from_config( - self.basic_scheduler_config, - AscendSchedulerConfig( - delay_factor=1, - max_num_batched_tokens=2048, - max_model_len=2048, - ), - ) - self.assertIn( - "currently AscendScheduler doesn't support scheduler_delay_factor", - str(context.exception), - ) - def test_no_override(self): ascend_config = AscendSchedulerConfig.initialize_from_config( self.basic_scheduler_config, {}) self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192) self.assertEqual(ascend_config.encoder_cache_size, 8192) + def test_valid_config_with_multimodal(self): + config = AscendSchedulerConfig.initialize_from_config( + SchedulerConfig(is_multimodal_model=True), {}) + 
self.assertTrue(config.is_multimodal_model) + def test_valid_config_with_chunked_prefill(self): ascend_config = AscendSchedulerConfig.initialize_from_config( self.basic_scheduler_config, @@ -165,3 +133,16 @@ class TestAscendSchedulerConfig(TestBase): ) self.assertIn("max_num_batched_tokens (2048)", str(context.exception)) self.assertIn("max_model_len (4096)", str(context.exception)) + + def test_initialize_from_config_with_pd_transfer(self): + ascend_config = AscendSchedulerConfig.initialize_from_config( + self.basic_scheduler_config, + AscendSchedulerConfig( + enable_pd_transfer=True, + decode_max_num_seqs=48, + max_num_batched_tokens=4096, + max_model_len=4096, + ), + ) + self.assertEqual(ascend_config.enable_pd_transfer, True) + self.assertEqual(ascend_config.decode_max_num_seqs, 48) diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py index 1855c80..d723e0a 100644 --- a/tests/ut/core/test_scheduler.py +++ b/tests/ut/core/test_scheduler.py @@ -6,25 +6,21 @@ from unittest.mock import MagicMock, patch import torch from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import PlaceholderRange +from vllm.multimodal.inputs import (MultiModalFeatureSpec, + MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams +from vllm.utils import sha256 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager from tests.ut.base import TestBase from vllm_ascend.core.scheduler import AscendScheduler -from vllm_ascend.utils import vllm_version_is - 
-if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")): - from vllm.v1.outputs import DraftTokenIds -else: - DraftTokenIds = None EOS_TOKEN_ID = 50256 MODEL = "Qwen3-0.6B" @@ -44,7 +40,7 @@ def create_requests( max_tokens: int = 16, stop_token_ids: Optional[list[int]] = None, block_size: int = 3, - hash_fn=hash, + hash_fn=sha256, ): init_none_hash(hash_fn) prompt_logprobs = PROMPT_LOGPROBS @@ -54,25 +50,25 @@ def create_requests( prompt_logprobs=prompt_logprobs) requests = [] for i in range(num_requests): - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - request = Request(request_id=f"{i}", - prompt_token_ids=[i] * num_tokens, - sampling_params=sampling_params, - multi_modal_kwargs=None, - multi_modal_placeholders=None, - multi_modal_hashes=None, - eos_token_id=EOS_TOKEN_ID, - pooling_params=None, - block_hasher=get_request_block_hasher( - block_size, hash_fn)) - else: - request = Request(request_id=f"{i}", - prompt_token_ids=[i] * num_tokens, - sampling_params=sampling_params, - eos_token_id=EOS_TOKEN_ID, - pooling_params=None, - block_hasher=get_request_block_hasher( - block_size, hash_fn)) + mm_features = [] + if mm_positions is not None: + mm_position = mm_positions[i] + for j, position in enumerate(mm_position): + identifier = f"hash{i}_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image") + mm_features.append(mm_feature) + request = Request(request_id=f"{i}", + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + eos_token_id=EOS_TOKEN_ID, + pooling_params=None, + mm_features=mm_features if mm_features else None, + block_hasher=get_request_block_hasher( + block_size, hash_fn)) requests.append(request) return requests @@ -85,25 +81,15 @@ def make_output(scheduler): } sampled_token_ids = [[1000]] * len(scheduler.running) logprobs = None - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - 
modelrunner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=req_id_to_index, - sampled_token_ids=sampled_token_ids, - spec_token_ids=None, - logprobs=logprobs, - prompt_logprobs_dict={}, - pooler_output=[], - ) - else: - modelrunner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=req_id_to_index, - sampled_token_ids=sampled_token_ids, - logprobs=logprobs, - prompt_logprobs_dict={}, - pooler_output=[], - ) + + modelrunner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + logprobs=logprobs, + prompt_logprobs_dict={}, + pooler_output=[], + ) return modelrunner_output @@ -113,7 +99,7 @@ class TestAscendScheduler(TestBase): @patch("vllm.config.VllmConfig.__post_init__", MagicMock()) @patch('vllm.v1.core.sched.scheduler.compute_encoder_budget') def create_scheduler(self, mock_compute_encoder_budget): - mock_compute_encoder_budget.return_value = [10, 20] + mock_compute_encoder_budget.return_value = [100, 100] use_kv_connector = False block_size = 16 @@ -235,7 +221,7 @@ class TestAscendScheduler(TestBase): len(requests) - i - 1) def test_schedule(self): - '''Test scheduling. + '''Test scheduling. 
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs ''' scheduler = self.create_scheduler() @@ -260,6 +246,60 @@ class TestAscendScheduler(TestBase): for i, request in enumerate(requests): self.assertEqual(scheduler.running[i], request) + def test_schedule_multimodal_requests(self): + scheduler = self.create_scheduler() + scheduler.scheduler_config.chunked_prefill_enabled = False + mm_positions = [[PlaceholderRange(offset=i, length=10)] + for i in range(10)] + requests = create_requests( + num_requests=10, + mm_positions=mm_positions, + ) + for request in requests: + scheduler.add_request(request) + + output = scheduler.schedule() + self.assertEqual(len(output.scheduled_new_reqs), len(requests)) + self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0) + self.assertEqual(len(output.finished_req_ids), 0) + for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert num_tokens == len(requests[int(req_id)].prompt_token_ids) + + # Verify all requests are scheduled. + for req_id, num_tokens in output.num_scheduled_tokens.items(): + self.assertEqual(num_tokens, + len(requests[int(req_id)].prompt_token_ids)) + self.assertEqual(len(output.scheduled_encoder_inputs), len(requests)) + for req_id, encoder_input in output.scheduled_encoder_inputs.items(): + assert len(encoder_input) == 1 + + # Verify requests moved from waiting to running + self.assertEqual(len(scheduler.waiting), 0) + self.assertEqual(len(scheduler.running), len(requests)) + for i, request in enumerate(requests): + self.assertEqual(scheduler.running[i], request) + + def test_concurrent_partial_prefills_schedule(self): + '''Test concurrent partial prefills scheduling. + total requests = 10, every request has 10 token. + while set long_prefill_token_threshold = 1, scheduler can + only schedule max_long_partial_prefills long request. 
+ ''' + scheduler = self.create_scheduler() + scheduler.scheduler_config.chunked_prefill_enabled = False + scheduler.scheduler_config.max_long_partial_prefills = 2 + scheduler.scheduler_config.long_prefill_token_threshold = 1 + requests = create_requests(num_requests=10, num_tokens=20) + for request in requests: + scheduler.add_request(request) + + # Test initial scheduling + output = scheduler.schedule() + self.assertEqual(len(output.scheduled_new_reqs), + scheduler.scheduler_config.max_long_partial_prefills) + self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0) + self.assertEqual(len(output.finished_req_ids), 0) + def test_schedule_enable_prefix_caching(self): '''Test scheduling. Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs @@ -304,69 +344,34 @@ class TestAscendScheduler(TestBase): scheduler.running.append(req) req.status = RequestStatus.RUNNING - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 1, - requests[1].request_id: 2 - }, - total_num_scheduled_tokens=3, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [], - requests[1].request_id: [10] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None) - model_output = ModelRunnerOutput( - req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[EOS_TOKEN_ID], [ - 10, 11 - ]], # First request hits EOS, second continues - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) - else: - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 1, - requests[1].request_id: 2 - }, - 
total_num_scheduled_tokens=3, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [], - requests[1].request_id: [10] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_mm_hashes=[], - structured_output_request_ids={}, - grammar_bitmask=None) - model_output = ModelRunnerOutput( - req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[EOS_TOKEN_ID], [ - 10, 11 - ]], # First request hits EOS, second continues - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 1, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [], + requests[1].request_id: [10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None) + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + sampled_token_ids=[[EOS_TOKEN_ID], [10, 11] + ], # First request hits EOS, second continues + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -391,67 +396,35 @@ class TestAscendScheduler(TestBase): scheduler.running.append(req) req.status = RequestStatus.RUNNING - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 2 - }, - total_num_scheduled_tokens=5, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [10, 42], - 
requests[1].request_id: [13] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None) - model_output = ModelRunnerOutput( - req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[10, 42, 12], - [13, 14]], # First request hits stop token - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) - else: - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 2 - }, - total_num_scheduled_tokens=5, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [10, 42], - requests[1].request_id: [13] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_mm_hashes=[], - structured_output_request_ids={}, - grammar_bitmask=None) - model_output = ModelRunnerOutput( - req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[10, 42, 12], - [13, 14]], # First request hits stop token - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=5, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: + [10, 42], + requests[1].request_id: [13] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None) + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + sampled_token_ids=[[10, 
42, 12], + [13, 14]], # First request hits stop token + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -475,67 +448,35 @@ class TestAscendScheduler(TestBase): scheduler.running.append(req) req.status = RequestStatus.RUNNING - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 1 - }, - total_num_scheduled_tokens=4, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [10, 11], - requests[1].request_id: [] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None) - model_output = ModelRunnerOutput( - req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[10, 11, 12], - [13]], # First request exceeds max_tokens - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) - else: - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 1 - }, - total_num_scheduled_tokens=4, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [10, 11], - requests[1].request_id: [] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_mm_hashes=[], - structured_output_request_ids={}, - grammar_bitmask=None) - model_output = ModelRunnerOutput( - req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[10, 11, 12], - [13]], # First request exceeds max_tokens - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) + 
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 1 + }, + total_num_scheduled_tokens=4, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: + [10, 11], + requests[1].request_id: [] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None) + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + sampled_token_ids=[[10, 11, 12], + [13]], # First request exceeds max_tokens + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) # Verify first request stopped due to length @@ -556,52 +497,27 @@ class TestAscendScheduler(TestBase): scheduler.requests[requests[0].request_id] = requests[0] scheduler.running.append(requests[0]) - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={requests[0].request_id: 3}, - total_num_scheduled_tokens=3, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [EOS_TOKEN_ID, 10] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None) - model_output = ModelRunnerOutput( - req_ids=[requests[0].request_id], - req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) - - else: - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={requests[0].request_id: 3}, - total_num_scheduled_tokens=3, - 
scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [EOS_TOKEN_ID, 10] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_mm_hashes=[], - structured_output_request_ids={}, - grammar_bitmask=None) - model_output = ModelRunnerOutput( - req_ids=[requests[0].request_id], - req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={requests[0].request_id: 3}, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [EOS_TOKEN_ID, 10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None) + model_output = ModelRunnerOutput( + req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, + sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -652,23 +568,13 @@ class TestAscendScheduler(TestBase): 512) # Model output of the first request. 
- if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - model_runner_output = ModelRunnerOutput( - req_ids=[requests[0].request_id], - req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[0]], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) - else: - model_runner_output = ModelRunnerOutput( - req_ids=[requests[0].request_id], - req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[0]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) + model_runner_output = ModelRunnerOutput( + req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, + sampled_token_ids=[[0]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) scheduler.update_from_output(scheduler_output0, model_runner_output) @@ -678,23 +584,13 @@ class TestAscendScheduler(TestBase): # request is still running. scheduler.schedule() # Model output of the second request. - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - model_runner_output = ModelRunnerOutput( - req_ids=[requests[1].request_id], - req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[[0]], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) - else: - model_runner_output = ModelRunnerOutput( - req_ids=[requests[1].request_id], - req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[[0]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) + model_runner_output = ModelRunnerOutput( + req_ids=[requests[1].request_id], + req_id_to_index={requests[1].request_id: 0}, + sampled_token_ids=[[0]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) scheduler.update_from_output(scheduler_output1, model_runner_output) @@ -746,29 +642,19 @@ class TestAscendScheduler(TestBase): req_id = requests[i].request_id self.assertEqual(output.num_scheduled_tokens[req_id], 1) self.assertNotIn(req_id, output.scheduled_spec_decode_tokens) - if 
vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=req_to_index, - sampled_token_ids=[[0] for _ in range(len(requests))], - logprobs=None, - prompt_logprobs_dict={}, - spec_token_ids=spec_tokens, - pooler_output=[]) - else: - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=req_to_index, - sampled_token_ids=[[0] for _ in range(len(requests))], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) - draft_token_ids = DraftTokenIds(req_ids, spec_tokens) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[0] for _ in range(len(requests))], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) + draft_token_ids = DraftTokenIds(req_ids, spec_tokens) engine_core_outputs = scheduler.update_from_output( output, model_runner_output) - if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")): - scheduler.update_draft_token_ids(draft_token_ids) + scheduler.update_draft_token_ids(draft_token_ids) for i in range(len(requests)): running_req = scheduler.running[i] @@ -804,23 +690,14 @@ class TestAscendScheduler(TestBase): else: self.assertNotIn(req_id, output.scheduled_spec_decode_tokens) - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=req_to_index, - sampled_token_ids=output_tokens, - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) - else: - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=req_to_index, - sampled_token_ids=output_tokens, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[]) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=output_tokens, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) engine_core_outputs = 
scheduler.update_from_output( output, model_runner_output) @@ -896,3 +773,34 @@ class TestAscendScheduler(TestBase): # Confirm no memory leak. self.assert_scheduler_empty(scheduler) + + def test_scheduler_with_pd_transfer(self): + scheduler = self.create_scheduler() + scheduler.phase = "prefill" + requests = create_requests(num_requests=32) + for request in requests: + scheduler.add_request(request) + + # 1st iteration, move 16 requests from waiting to running for prefill + scheduler_output = scheduler.schedule() + model_runner_output = make_output(scheduler) + scheduler.update_from_output(scheduler_output, model_runner_output) + first_iter_prefilled_req_num = len(scheduler.running) + self.assertEqual(len(scheduler_output.scheduled_new_reqs), + scheduler.max_num_running_reqs) + self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0) + self.assertEqual(len(scheduler_output.finished_req_ids), 0) + + # 2nd iteration, move 16 prefilled requests to finished_prefill_reqs + # and move 16 requests from waiting to running for prefill + scheduler_output = scheduler.schedule() + model_runner_output = make_output(scheduler) + scheduler.update_from_output(scheduler_output, model_runner_output) + self.assertEqual(len(scheduler.finished_prefill_reqs), + first_iter_prefilled_req_num) + + # 3rd iteration, all requests prefilled, change scheduler phase to decode + scheduler_output = scheduler.schedule() + model_runner_output = make_output(scheduler) + scheduler.update_from_output(scheduler_output, model_runner_output) + self.assertEqual(scheduler.phase, "decode") diff --git a/vllm_ascend/lora/punica_wrapper/__init__.py b/tests/ut/distributed/test_determin_expert_map_all.py similarity index 100% rename from vllm_ascend/lora/punica_wrapper/__init__.py rename to tests/ut/distributed/test_determin_expert_map_all.py diff --git a/tests/ut/distributed/test_distributed_tensor_parallel.py b/tests/ut/distributed/test_distributed_tensor_parallel.py deleted file mode 100644 index 
48a88fa..0000000 --- a/tests/ut/distributed/test_distributed_tensor_parallel.py +++ /dev/null @@ -1,139 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. - -import importlib - -import pytest -import torch -from pytest_mock import MockerFixture - -from tests.ut.base import PytestBase -from vllm_ascend.distributed.tensor_parallel import ( - _gather_along_first_dim, _gather_along_last_dim, - _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim, - all_to_all_hp2sp, all_to_all_sp2hp) - - -class TestDistributedCommunication(PytestBase): - - @pytest.fixture(autouse=True) - def context(self, mocker: MockerFixture): - mocker.patch("torch.npu.current_device", return_value="cpu") - mocker.patch("torch.distributed.get_world_size", return_value=4) - - mocker.patch("torch.distributed.get_rank", return_value=0) - - @pytest.mark.parametrize("world_size, test_tensor, expected", - [(1, torch.randn(8, 16), (8, 16)), - (4, torch.randn(8, 16), (32, 16))]) - def test_gather_along_first_dim(self, test_tensor, expected, world_size, - mocker: MockerFixture): - """test _gather_along_first_dim""" - mocker.patch("torch.distributed.get_world_size", - return_value=world_size) - - result = _gather_along_first_dim(test_tensor, mocker.MagicMock()) - - assert result.shape == expected - - @pytest.mark.parametrize("test_tensor, 
output_split_sizes, expected", [ - (torch.randn(8, 16), [5, 10, 15, 2], (32, 16)), - ]) - def test_gather_along_first_dim_unequal_split(self, test_tensor, expected, - output_split_sizes, - mocker: MockerFixture): - """test _gather_along_first_dim""" - - result = _gather_along_first_dim(test_tensor, mocker.MagicMock(), - output_split_sizes) - - assert result.shape == expected - - @pytest.mark.parametrize("world_size, test_tensor, expected", - [(1, torch.randn(8, 16, 32), (8, 16, 32)), - (4, torch.randn(8, 16, 32), (8, 16, 32 * 4))]) - def test_gather_along_last_dim(self, test_tensor, expected, world_size, - mocker: MockerFixture): - """test _gather_along_last_dim""" - mocker.patch("torch.distributed.get_world_size", - return_value=world_size) - - result = _gather_along_last_dim(test_tensor, mocker.MagicMock()) - - assert result.shape == expected - - @pytest.mark.parametrize("input_shape,expected_shape", [ - ((32, 16), (8, 16)), - ((40, 10), (10, 10)), - ]) - def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape, - mocker: MockerFixture): - input_tensor = torch.randn(*input_shape) - result = _reduce_scatter_along_first_dim(input_tensor, - mocker.MagicMock()) - assert result.shape == expected_shape - - @pytest.mark.parametrize("input_shape,expected_shape", [ - ((8, 16, 32), (8, 16, 8)), - ]) - def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape, - mocker: MockerFixture): - input_tensor = torch.randn(*input_shape) - result = _reduce_scatter_along_last_dim(input_tensor, - mocker.MagicMock()) - assert result.shape == expected_shape - - @pytest.mark.parametrize("func,input_shape,expected_shape", [ - ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), - (8, 16, 128)), - ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)), - ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32), - (8, 16, 8)), - ("gather_from_sequence_parallel_region", (8, 16), (32, 16)), - ]) - def test_wrapper_functions(self, 
func, input_shape, expected_shape, - mocker: MockerFixture): - """test wrapper funcs""" - mod = importlib.import_module( - 'vllm_ascend.distributed.tensor_parallel') - globals = mod.__dict__ - test_func = globals[func] - input_tensor = torch.randn(*input_shape) - result = test_func(input_tensor, mocker.MagicMock()) - assert result.shape == expected_shape - - @pytest.mark.parametrize( - "input_shape,output_shape", - [ - ((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP] - ]) - def test_all_to_all_sp2hp(self, input_shape, output_shape, - mocker: MockerFixture): - input_tensor = torch.randn(*input_shape) - result = all_to_all_sp2hp(input_tensor, mocker.MagicMock()) - assert result.shape == output_shape - - @pytest.mark.parametrize( - "input_shape,output_shape", - [ - ((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H] - ]) - def test_all_to_all_hp2sp(self, input_shape, output_shape, - mocker: MockerFixture): - input_tensor = torch.randn(*input_shape) - result = all_to_all_hp2sp(input_tensor, mocker.MagicMock()) - assert result.shape == output_shape diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py index afc22c8..6b52b7b 100644 --- a/tests/ut/distributed/test_parallel_state.py +++ b/tests/ut/distributed/test_parallel_state.py @@ -4,8 +4,8 @@ import pytest from vllm.config import ParallelConfig from vllm_ascend.distributed.parallel_state import ( - _LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group, - get_mc2_group, init_ascend_model_parallel) + _LMTP, _MC2, _OTP, destroy_ascend_model_parallel, get_lmhead_tp_group, + get_mc2_group, get_otp_group, init_ascend_model_parallel) @pytest.fixture @@ -29,16 +29,20 @@ def mock_distributed(): def test_init_ascend_model_parallel(mock_distributed, parallel_config): mock_ascend_config = MagicMock() mock_ascend_config.lmhead_tensor_parallel_size = 2 + mock_ascend_config.oproj_tensor_parallel_size = 2 with 
patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \ patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \ patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config): init_ascend_model_parallel(parallel_config) mc2_group = get_mc2_group() - assert mc2_group is not None lmheadtp_group = get_lmhead_tp_group() + otp_group = get_otp_group() + assert mc2_group is not None + assert otp_group is not None assert lmheadtp_group is not None destroy_ascend_model_parallel() assert _MC2 is None assert _LMTP is None + assert _OTP is None diff --git a/tests/ut/eplb/adaptor/test_abstract_adaptor.py b/tests/ut/eplb/adaptor/test_abstract_adaptor.py new file mode 100644 index 0000000..a3d93ca --- /dev/null +++ b/tests/ut/eplb/adaptor/test_abstract_adaptor.py @@ -0,0 +1,73 @@ +import pytest + +from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor + + +class DummyAdaptor(EplbAdaptor): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.args = kwargs + + def get_rank_expert_workload(self): + return "workload" + + def get_init_expert_map(self, num_moe_layers): + return {"layers": num_moe_layers} + + def do_update_expert_map(self, layer_id, updated_expert_map): + return {"layer_id": layer_id, "map": updated_expert_map} + + def do_update_expert_weight(self, layer_id, local_expert_to_replace, + buffer_tensor_id): + return { + "layer_id": layer_id, + "replace": local_expert_to_replace, + "buffer": buffer_tensor_id, + } + + +def test_base_class_methods_raise(): + adaptor = EplbAdaptor() + with pytest.raises(NotImplementedError): + adaptor.get_rank_expert_workload() + with pytest.raises(NotImplementedError): + adaptor.get_init_expert_map(1) + with pytest.raises(NotImplementedError): + adaptor.do_update_expert_map(1, {}) + with pytest.raises(NotImplementedError): + adaptor.do_update_expert_weight(1, "x", "y") + + +def test_dummy_adaptor_init_and_args(): + 
adaptor = DummyAdaptor(test_arg=123) + assert adaptor.args["test_arg"] == 123 + + +def test_get_rank_expert_workload(): + adaptor = DummyAdaptor() + result = adaptor.get_rank_expert_workload() + assert result == "workload" + + +def test_get_init_expert_map(): + adaptor = DummyAdaptor() + result = adaptor.get_init_expert_map(5) + assert isinstance(result, dict) + assert result["layers"] == 5 + + +def test_do_update_expert_map(): + adaptor = DummyAdaptor() + updated = {"expert": 1} + result = adaptor.do_update_expert_map(2, updated) + assert result["layer_id"] == 2 + assert result["map"] == updated + + +def test_do_update_expert_weight(): + adaptor = DummyAdaptor() + result = adaptor.do_update_expert_weight(1, "expertA", "bufferX") + assert result["layer_id"] == 1 + assert result["replace"] == "expertA" + assert result["buffer"] == "bufferX" diff --git a/tests/ut/eplb/core/policy/test_policy_abstract.py b/tests/ut/eplb/core/policy/test_policy_abstract.py new file mode 100644 index 0000000..26eb28b --- /dev/null +++ b/tests/ut/eplb/core/policy/test_policy_abstract.py @@ -0,0 +1,31 @@ +# test_policy_abstract.py +from vllm_ascend.eplb.core.policy.policy_abstract import (DynamicConfig, + EplbPolicy) + + +class DummyPolicy(EplbPolicy): + + def rebalance_experts(self, current_expert_table, expert_workload): + return 1, current_expert_table + + +def test_dynamic_config_attributes(): + config = DynamicConfig() + assert config.placement_policy is None + assert config.max_transferred_expert_per_layer == 100 + assert config.ep_worldsize == 64 + assert config.num_die_per_host == 8 + + +def test_eplb_policy_init_and_method(): + config = DynamicConfig() + policy = DummyPolicy(config) + + assert policy.config == config + + expert_table = [[0, 1, 2]] + workload = [10] + res, new_table = policy.rebalance_experts(expert_table, workload) + + assert res == 1 + assert new_table == expert_table diff --git a/tests/ut/eplb/core/policy/test_policy_dynamic_ep.py 
b/tests/ut/eplb/core/policy/test_policy_dynamic_ep.py new file mode 100644 index 0000000..f432d9b --- /dev/null +++ b/tests/ut/eplb/core/policy/test_policy_dynamic_ep.py @@ -0,0 +1,98 @@ +from unittest.mock import patch + +import numpy as np +import pytest + +from vllm_ascend.eplb.core.policy.policy_dynamic_ep import DynamicEplb + + +class TestDynamicEplb: + + def test_add_redundant_basic(self): + current_expert_table = np.array([[[0, 1], [1, 0]]]) + expert_workload = np.array([[[2, 3], [4, 1]]]) + num_original_expert = 2 + result = DynamicEplb.add_redundant(current_expert_table, + expert_workload, + num_original_expert) + expected = np.array([[2 + 1, 3 + 4]]) + assert np.array_equal(result, expected) + + def test_get_redundant_num(self): + counts = np.array([2, 1, 3]) + assert DynamicEplb.get_redundant_num(3, counts) == 3 + + def test_calculate_max_heat_per_layer(self): + workload_table = np.array([[[1, 2], [3, 4]], [[2, 2], [1, 1]]]) + max_heat = DynamicEplb.calculate_max_heat_per_layer(workload_table, 2) + assert max_heat == [7, 4] + + def test_constraint_expert_local_exchange(self): + current = [[[0, 1], [2, 3]]] + global_dep = [[[1, 0], [3, 2]]] + new_dep = DynamicEplb.constraint_expert_local_exchange( + current, global_dep) + assert new_dep == [[[0, 1], [2, 3]]] + + def test_compute_balanced_pack_redundancy_normal(self): + origin_weights = [(0, 10), (1, 20)] + result, boxes = DynamicEplb.compute_balanced_pack_redundancy( + origin_weights, 2, 1) + assert isinstance(result, list) and len(result) == 2 + + def test_compute_balanced_pack_redundancy_card0(self): + origin_weights = [(0, 10)] + with pytest.raises(RuntimeError): + DynamicEplb.compute_balanced_pack_redundancy(origin_weights, 0, 0) + + def test_compute_balanced_pack_normal(self): + origin_weights = np.array([(0, 10), (1, 20)], dtype=object) + result, boxes = DynamicEplb.compute_balanced_pack(origin_weights, 2) + assert isinstance(result, list) and len(result) == 2 + + def 
test_compute_balanced_pack_card0(self): + origin_weights = np.array([(0, 10)], dtype=object) + with pytest.raises(RuntimeError): + DynamicEplb.compute_balanced_pack(origin_weights, 0) + + def test_original_compute_balanced_pack_redundancy(self): + origin_weights = [(0, 5), (1, 10)] + result, boxes = DynamicEplb.original_compute_balanced_pack_redundancy( + origin_weights, 2, 1) + assert isinstance(result, list) and len(result) == 2 + + def test_rebalance_experts_normal(self): + expert_table = np.array([[[0, 1], [1, 0]]]) + workload = np.array([[[2, 3], [4, 1]]]) + policy = DynamicEplb(config=None) + change, priority, new_dep = policy.rebalance_experts( + expert_table, workload) + assert change in [0, 1] + assert isinstance(priority, np.ndarray) + assert isinstance(new_dep, list) + assert np.array(new_dep).shape == expert_table.shape + + def test_rebalance_experts_exceptions(self): + policy = DynamicEplb(config=None) + + # case1: num_original_expert != expert_num + expert_table = np.array([[[0, 1], [1, 0]]]) + workload = np.array([[[2, 3], [4, 1]]]) + with patch.object(DynamicEplb, + 'add_redundant', + return_value=np.array([[1, 2, 3]])): + with pytest.raises(ValueError): + policy.rebalance_experts(expert_table, workload) + + # case2: num_npus <= 0 + expert_table_zero = np.array([[]]) # 1 layer, 0 NPU, 0 experts + workload_zero = np.array([[]]) + with pytest.raises(ValueError): + policy.rebalance_experts(expert_table_zero, workload_zero) + + # case3: num_npus < num_redundancy_expert + expert_table_small = np.array([[[0, 0]]]) # 1 layer, 1 NPU, 2 experts + workload_small = np.array([[[1, 1]]]) + with patch.object(DynamicEplb, 'get_redundant_num', return_value=2): + with pytest.raises(ValueError): + policy.rebalance_experts(expert_table_small, workload_small) diff --git a/tests/ut/eplb/core/policy/test_policy_dynamic_ep_v2.py b/tests/ut/eplb/core/policy/test_policy_dynamic_ep_v2.py new file mode 100644 index 0000000..eddd18c --- /dev/null +++ 
b/tests/ut/eplb/core/policy/test_policy_dynamic_ep_v2.py @@ -0,0 +1,99 @@ +from typing import Dict, Set + +import numpy as np +import pytest + +from vllm_ascend.eplb.core.policy.policy_dynamic_ep_v2 import (DynamicConfig, + DynamicEplbV2) + + +@pytest.fixture +def config(): + return DynamicConfig() + + +@pytest.fixture +def policy(config): + return DynamicEplbV2(config) + + +def test_safe_operations(policy): + # safe_divide + assert policy.safe_divide(10, 2) == 5 + assert policy.safe_divide(1, 0) == 0 + + # safe_exact_divide + assert policy.safe_exact_divide(10, 3) == 3 + assert policy.safe_exact_divide(1, 0) == 0 + + # safe_mod + assert policy.safe_mod(10, 3) == 1 + assert policy.safe_mod(1, 0) == 0 + + +def test_add_redundant(): + workload = np.array([[[1, 2], [3, 4]]]) + placement = np.array([[[0, 1], [0, 1]]]) + result = DynamicEplbV2.add_redundant(placement, workload, 2) + assert result.shape == (1, 2) + assert np.all(result[0] == [4, 6]) # 0:1+3, 1:2+4 + + +def test_get_redundant_num(): + counts = np.array([1, 2, 1]) + assert DynamicEplbV2.get_redundant_num(3, counts) == 1 # sum(counts-1) + + +def test_calculate_max_heat_per_layer(): + workload = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + result = DynamicEplbV2.calculate_max_heat_per_layer(workload, 2) + assert result == [7, 15] + + +def test_calculate_initial_imbalance(policy): + deployment = np.array([[[0, 1], [0, 1]]]) + workloads = np.array([[1, 1]]) + result = policy.calculate_initial_imbalance(deployment, workloads) + assert isinstance(result, list) + assert len(result) == 1 + + +def test_compute_redundant_assignments(policy): + base_experts = [(0, 10), (1, 5)] + redundant, sorted_weights = policy.compute_redundant_assignments( + base_experts, num_redundant_experts=2, num_experts=2) + assert len(redundant) == 2 + assert len(sorted_weights) == 2 + + +def test_prepare_expert_list(): + base_experts = [(0, 10), (1, 5)] + redundant_assignments = [[2], []] + result = 
DynamicEplbV2.prepare_expert_list(base_experts, + redundant_assignments, 1) + assert isinstance(result, list) + assert len(result) == 1 + + +def test_non_redundant_expert_information(): + origin_deployment = np.array([[0, 1]]) + updated_weights = [(0, 10), (1, 5)] + rendun_pos: Dict[int, Set[int]] = {0: set()} + assignments, weights, loads, counts = DynamicEplbV2.non_redundant_expert_information( + origin_deployment, updated_weights, rendun_pos) + assert assignments[0] == [0, 1] + assert loads[0] == 15 + + +def test_recomputing_initial_weight(policy): + layer_workloads = [10, 5] + device_assignments = [[0, 1]] + cur_layer_workload, num_all_experts = policy.recomputing_initial_weight( + layer_workloads, device_assignments) + assert cur_layer_workload[0] == 10 + assert num_all_experts[0] == 1 + + +def test_safe_divide_zero_edge_case(policy): + assert policy.safe_divide(0, 1) == 0 + assert policy.safe_divide(0, 5) == 0 diff --git a/tests/ut/eplb/core/policy/test_policy_factor.py b/tests/ut/eplb/core/policy/test_policy_factor.py new file mode 100644 index 0000000..7894335 --- /dev/null +++ b/tests/ut/eplb/core/policy/test_policy_factor.py @@ -0,0 +1,23 @@ +import pytest + +from vllm_ascend.eplb.core.policy.policy_abstract import DynamicConfig +from vllm_ascend.eplb.core.policy.policy_dynamic_ep import DynamicEplb +from vllm_ascend.eplb.core.policy.policy_dynamic_ep_v2 import DynamicEplbV2 +from vllm_ascend.eplb.core.policy.policy_factory import PolicyFactory +from vllm_ascend.eplb.core.policy.policy_random import RandomLoadBalance + + +@pytest.fixture +def dummy_config(): + return DynamicConfig() + + +@pytest.mark.parametrize("policy_type, expected_class", [ + (0, RandomLoadBalance), + (1, DynamicEplb), + (2, DynamicEplbV2), + (999, RandomLoadBalance), +]) +def test_generate_policy(policy_type, expected_class, dummy_config): + policy_instance = PolicyFactory.generate_policy(policy_type, dummy_config) + assert isinstance(policy_instance, expected_class) diff --git 
a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py new file mode 100644 index 0000000..8835ff5 --- /dev/null +++ b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py @@ -0,0 +1,122 @@ +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +import torch + +import vllm_ascend.eplb.core.eplb_device_transfer_loader as loader + + +@pytest.fixture +def mock_adaptor(): + adaptor = MagicMock() + + adaptor.expert_map_per_layer_cpu = { + 0: { + 10: torch.tensor(1), + 20: torch.tensor(0) + } + } + + adaptor.expert_param_per_layer = { + 0: { + 0: [[torch.tensor([1.0])]], + 1: [[torch.tensor([2.0])]] + } + } + + adaptor.buffer_tensor_list = [[[torch.tensor([3.0])], + [torch.tensor([4.0])]]] + return adaptor + + +def test_generate_task_and_state_flow(mock_adaptor): + loader_obj = loader.D2DExpertWeightLoader() + loader_obj.set_adator(mock_adaptor) + + with patch("torch.distributed.P2POp") as mock_p2p, \ + patch("torch.distributed.isend", return_value="isend_op"), \ + patch("torch.distributed.irecv", return_value="irecv_op"): + + mock_p2p.side_effect = lambda op, tensor, rank: (op, tensor, rank) + + loader_obj.state = loader.ExpertWeightUpdateState.READY + loader_obj.generate_expert_d2d_transfer_task([(1, 10)], [(2, 20)], + {20: torch.tensor(0)}, 0) + assert loader_obj.comm_op_list is None + loader_obj.state = loader.ExpertWeightUpdateState.WAITING + + loader_obj.generate_expert_d2d_transfer_task([], [], {}, 0) + assert loader_obj.comm_op_list is None + + updated_map = {20: torch.tensor(0)} + loader_obj.generate_expert_d2d_transfer_task([(1, 10)], [(2, 20)], + updated_map, 0) + assert loader_obj.state == loader.ExpertWeightUpdateState.READY + assert loader_obj.comm_op_list + assert loader_obj.recv_expert_list + + +def test_asyn_transfer_and_update(mock_adaptor): + loader_obj = loader.D2DExpertWeightLoader() + loader_obj.set_adator(mock_adaptor) + + loader_obj.comm_op_list = 
["fake_op"] + loader_obj.state = loader.ExpertWeightUpdateState.READY + + reqs: list[MagicMock] = [] + + with patch("torch.distributed.batch_isend_irecv", + return_value=[MagicMock(), MagicMock()]): + loader_obj.asyn_expert_weight_transfer(reqs) + + assert loader_obj.state == loader.ExpertWeightUpdateState.TRANSFERRING + assert len(reqs) > 0 + + mock_req = MagicMock() + mock_req.wait.return_value = None + reqs = [mock_req] + + loader_obj.recv_expert_list = [(0, 0)] + loader_obj.updated_expert_map = {20: torch.tensor(0)} + loader_obj.updated_log2phy_map = {"dummy": 1} + loader_obj.layer_id = 0 + loader_obj.comm_op_list = ["op"] + + loader_obj.update_expert_map_and_weight(reqs) + + mock_adaptor.do_update_expert_map.assert_called_once() + mock_adaptor.do_update_log2phy_map.assert_called_once() + mock_adaptor.do_update_expert_weight.assert_called_once() + + assert loader_obj.state == loader.ExpertWeightUpdateState.WAITING + assert loader_obj.recv_expert_list == [] + + +def test_set_log2phy_map(mock_adaptor): + loader_obj = loader.D2DExpertWeightLoader() + loader_obj.set_adator(mock_adaptor) + loader_obj.set_log2phy_map({"a": 1}) + assert loader_obj.updated_log2phy_map == {"a": 1} + + +def test_invalid_state_asyn_update(mock_adaptor): + loader_obj = loader.D2DExpertWeightLoader() + loader_obj.set_adator(mock_adaptor) + + loader_obj.state = loader.ExpertWeightUpdateState.WAITING + reqs: list[Any] = [] + loader_obj.asyn_expert_weight_transfer(reqs) + assert reqs == [] + + loader_obj.state = loader.ExpertWeightUpdateState.READY + loader_obj.update_expert_map_and_weight([]) + + assert not mock_adaptor.do_update_expert_map.called + + +def test_load_impl_not_implemented(mock_adaptor): + loader_obj = loader.D2DExpertWeightLoader() + loader_obj.set_adator(mock_adaptor) + with pytest.raises(NotImplementedError): + loader_obj.load_impl({}, {}) diff --git a/tests/ut/eplb/core/test_eplb_utils.py b/tests/ut/eplb/core/test_eplb_utils.py new file mode 100644 index 0000000..8a9761f --- 
/dev/null +++ b/tests/ut/eplb/core/test_eplb_utils.py @@ -0,0 +1,79 @@ +import random + +import torch + +from vllm_ascend.eplb.core import eplb_utils + + +def test_determine_default_expert_map_single_world(): + count, expert_map = eplb_utils.determine_default_expert_map( + global_expert_num=4, + world_size=1, + rank_id=0, + global_redundant_expert_num=0) + assert count == 4 + assert torch.equal(expert_map, torch.arange(4, dtype=torch.int32)) + + +def test_determine_default_expert_map_multiple_worlds_no_redundant(): + count, expert_map = eplb_utils.determine_default_expert_map( + global_expert_num=8, + world_size=2, + rank_id=0, + global_redundant_expert_num=0) + + assert count == 4 + assert torch.all(expert_map[:4] >= 0) + assert torch.all(expert_map[4:] == -1) + + +def test_determine_default_expert_map_multiple_worlds_with_redundant(): + count, expert_map = eplb_utils.determine_default_expert_map( + global_expert_num=5, + world_size=2, + rank_id=0, + global_redundant_expert_num=1) + + assert count == 3 + assert torch.all(expert_map[0:3] >= 0) + + +def test_generate_log2phy_map_single_rank_holding(): + + expert_map = torch.tensor([[0, -1], [-1, 0]], dtype=torch.int32) + log2phy_map = eplb_utils.generate_log2phy_map(expert_map) + + assert torch.all(log2phy_map[:, 0] == log2phy_map[0, 0]) + assert torch.all(log2phy_map[:, 1] == log2phy_map[1, 1]) + + +def test_generate_log2phy_map_multiple_rank_holding(monkeypatch): + + expert_map = torch.tensor([[0], [0]], dtype=torch.int32) + + monkeypatch.setattr(random, "choice", lambda x: x[0]) + + log2phy_map = eplb_utils.generate_log2phy_map(expert_map) + + assert log2phy_map.shape == (2, 1) + assert (log2phy_map >= 0).all() + + +def test_determine_default_log2phy_map_world_size_1(): + log2phy = eplb_utils.determine_default_log2phy_map( + global_expert_num=3, + world_size=1, + rank_id=0, + global_redundant_expert_num=0) + assert log2phy.shape == (3, ) + assert (log2phy >= 0).all() + + +def 
test_determine_default_log2phy_map_world_size_multiple(): + log2phy = eplb_utils.determine_default_log2phy_map( + global_expert_num=6, + world_size=2, + rank_id=1, + global_redundant_expert_num=1) + assert log2phy.shape == (6, ) + assert (log2phy >= 0).all() diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 0b2782d..2ea23bc 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -7,6 +7,7 @@ import time import types import unittest from collections import defaultdict, deque +from typing import OrderedDict from unittest.mock import MagicMock, patch import msgspec @@ -34,7 +35,7 @@ class TestKVCacheTaskTrackerInit(unittest.TestCase): tracker = KVCacheTaskTracker() self.assertIsInstance(tracker.done_task_lock, type(threading.Lock())) self.assertIsInstance(tracker.finished_requests, set) - self.assertIsInstance(tracker.delayed_free_requests, deque) + self.assertIsInstance(tracker.delayed_free_requests, OrderedDict) class TestGetAndClearFinishedSingleRequests(unittest.TestCase): @@ -495,18 +496,42 @@ class TestKVCacheTaskTracker(unittest.TestCase): def test_update_done_task_count(self): self.assertEqual(len(self.tracker.finished_requests), 0) self.assertEqual(len(self.tracker.delayed_free_requests), 0) + self.assertEqual(len(self.tracker.record_finished_requests), 0) current_time = time.time() self.tracker.add_delayed_request("req_1", current_time) result = self.tracker.delayed_free_requests + result_record = self.tracker.record_finished_requests self.assertEqual(len(result), 1) - self.assertEqual(result[0], ("req_1", current_time)) + self.assertEqual(result["req_1"], current_time) + self.assertEqual(len(result_record), 0) self.tracker.update_done_task_count("req_1") result_finished = self.tracker.finished_requests result_delayed = self.tracker.delayed_free_requests + result_record = self.tracker.record_finished_requests 
self.assertEqual(result_finished, {"req_1"}) self.assertEqual(len(result_delayed), 0) + self.assertEqual(len(result_record), 0) + + self.tracker.update_done_task_count("req_2") + result_finished = self.tracker.finished_requests + result_delayed = self.tracker.delayed_free_requests + result_record = self.tracker.record_finished_requests + self.assertEqual(result_finished, {"req_1", "req_2"}) + self.assertEqual(len(result_delayed), 0) + self.assertEqual(len(result_record), 1) + self.assertEqual(result_record, {"req_2"}) + + def test_updtate_add_delayed_request(self) -> None: + self.tracker.update_done_task_count("req2") + result_start_record = self.tracker.record_finished_requests + self.assertEqual(len(result_start_record), 1) + self.tracker.add_delayed_request("req2", time.time()) + result_delayed = self.tracker.delayed_free_requests + result_end_record = self.tracker.record_finished_requests + self.assertEqual(len(result_delayed), 0) + self.assertEqual(len(result_end_record), 0) def test_retrieve_expired_requests(self): current_time = time.time() @@ -518,7 +543,7 @@ class TestKVCacheTaskTracker(unittest.TestCase): }) result_delay = self.tracker.delayed_free_requests self.assertEqual(len(result_delay), 1) - self.assertEqual(result_delay[0], ("req_2", current_time)) + self.assertIn("req_2", result_delay) def test_duplicate_task_update(self): self.tracker.update_done_task_count("req1") @@ -961,6 +986,46 @@ class TestMooncakeConnectorWorker(unittest.TestCase): for p in self.patches: p.stop() # type: ignore + def test_worker_use_ascend_direct(self): + test_case = [True, False] + + for use_ascend_direct in test_case: + with self.subTest(use_ascend_direct=use_ascend_direct): + config = MagicMock() + config.kv_transfer_config = MagicMock() + config.kv_transfer_config.get_from_extra_config.side_effect = ( + lambda k, d: { + "prefill": { + "tp_size": 2, + "dp_size": 1 + }, + "decode": { + "tp_size": 2, + "dp_size": 1 + }, + "use_ascend_direct": use_ascend_direct, + }.get(k, 
d)) + + config.parallel_config = MagicMock() + config.parallel_config.tensor_parallel_size = 2 + config.parallel_config.data_parallel_rank_local = 0 + config.parallel_config.data_parallel_size_local = 1 + config.kv_transfer_config.kv_port = 8000 + config.kv_transfer_config.kv_role = 'worker' + + with patch( + "vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank", + return_value=0): + with patch( + "vllm_ascend.distributed.mooncake_connector.get_tp_group", + return_value=None): + with patch( + "vllm_ascend.distributed.mooncake_connector.get_ip", + return_value="127.0.0.1"): + worker = MooncakeConnectorWorker( + config, self.engine_id) + self.assertIsNotNone(worker) + def test_register_kv_caches_producer(self): worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) worker.register_kv_caches(self.kv_caches) diff --git a/tests/ut/kv_connector/utils.py b/tests/ut/kv_connector/utils.py index 13711e7..d1bf01f 100644 --- a/tests/ut/kv_connector/utils.py +++ b/tests/ut/kv_connector/utils.py @@ -10,6 +10,7 @@ import torch from vllm import SamplingParams from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig) +from vllm.utils import sha256 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) from vllm.v1.core.sched.scheduler import Scheduler @@ -19,8 +20,6 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager -from vllm_ascend.utils import vllm_version_is - EOS_TOKEN_ID = 50256 os.environ["VLLM_USE_V1"] = "1" @@ -131,10 +130,10 @@ def create_request( """Make dummy request for testing.""" global _none_hash_initialized if not _none_hash_initialized: - init_none_hash(hash) + init_none_hash(sha256) _none_hash_initialized = True - block_hasher = get_request_block_hasher(block_size, hash) + block_hasher = get_request_block_hasher(block_size, sha256) kv_transfer_params: 
Optional[dict[str, Any]] = None @@ -160,27 +159,14 @@ def create_request( else: prompt_token_ids = [i * request_id for i in range(num_tokens)] - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - req = Request( - request_id=f"id-{request_id}", - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, - multi_modal_kwargs=None, - multi_modal_placeholders=None, - multi_modal_hashes=None, - pooling_params=[], - eos_token_id=EOS_TOKEN_ID, - block_hasher=block_hasher, - ) - else: - req = Request( - request_id=f"id-{request_id}", - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, - pooling_params=[], - eos_token_id=EOS_TOKEN_ID, - block_hasher=block_hasher, - ) + req = Request( + request_id=f"id-{request_id}", + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + pooling_params=[], + eos_token_id=EOS_TOKEN_ID, + block_hasher=block_hasher, + ) req.kv_transfer_params = kv_transfer_params return req @@ -208,26 +194,15 @@ def create_model_runner_output( kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) extra_args = {"kv_connector_output": kv_connector_output} - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=req_id_to_index, - sampled_token_ids=sampled_token_ids, - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - **extra_args, - ) - else: - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=req_id_to_index, - sampled_token_ids=sampled_token_ids, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - **extra_args, - ) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + **extra_args, + ) return model_runner_output diff --git 
a/tests/ut/models/conftest.py b/tests/ut/models/conftest.py new file mode 100644 index 0000000..d929943 --- /dev/null +++ b/tests/ut/models/conftest.py @@ -0,0 +1,114 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest +import torch +from transformers import PretrainedConfig +from vllm.config import CacheConfig, EPLBConfig, ParallelConfig +from vllm.distributed.parallel_state import GroupCoordinator + + +@pytest.fixture +def base_config(): + config = PretrainedConfig( + hidden_size=128, + num_attention_heads=8, + num_hidden_layers=2, + intermediate_size=256, + hidden_act="silu", + rms_norm_eps=1e-6, + rope_theta=10000.0, + max_position_embeddings=2048, + n_routed_experts=4, + n_shared_experts=1, + moe_intermediate_size=256, + num_experts_per_tok=2, + routed_scaling_factor=1.0, + first_k_dense_replace=0, + moe_layer_freq=1, + kv_lora_rank=16, + qk_nope_head_dim=16, + qk_rope_head_dim=16, + v_head_dim=32, + topk_method="noaux_tc", + scoring_func="softmax", + norm_topk_prob=True, + n_group=1, + topk_group=1, + vocab_size=10000, + ) + return config + + +@pytest.fixture +def vllm_config(base_config): + model_config = SimpleNamespace( + hf_config=base_config, + tensor_parallel_size=1, + dtype=torch.float32, + use_mla=True, + quant_config=None, + max_model_len=2048, + ) + parallel_config = MagicMock(spec=ParallelConfig) + eplb_config = MagicMock(spec=EPLBConfig) + eplb_config.num_redundant_experts = 0 + parallel_config.eplb_config = eplb_config + + cache_config = CacheConfig() + vllm_config = Mock() + vllm_config.model_config = model_config + vllm_config.cache_config = cache_config + vllm_config.quant_config = None + vllm_config.parallel_config = parallel_config + return vllm_config + + +@pytest.fixture +def mock_distributed(): + tp_group = Mock(spec=GroupCoordinator) + tp_group.rank_in_group = 0 + tp_group.world_size = 1 + tp_group.device_group = Mock() + + dp_group = Mock(spec=GroupCoordinator) + 
dp_group.rank_in_group = 0 + dp_group.world_size = 1 + + ep_group = Mock(spec=GroupCoordinator) + ep_group.rank_in_group = 0 + ep_group.world_size = 1 + ep_group.device_group = Mock() + ep_group.device_group.rank.return_value = 0 + ep_group.device_group.size.return_value = 1 + + pp_group = Mock(spec=GroupCoordinator) + pp_group.rank_in_group = 0 + pp_group.world_size = 1 + + mock_vllm_config = Mock() + mock_vllm_config.scheduler_config = Mock(max_num_seqs=256) + mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None) + + with patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \ + patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \ + patch("vllm_ascend.models.deepseek_v2.get_tp_group", return_value=tp_group), \ + patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \ + patch("vllm_ascend.models.deepseek_v2.get_pp_group", + return_value=Mock(is_first_rank=False, is_last_rank=False)), \ + patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ + patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \ + patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \ + patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, + _PP=pp_group), \ + patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \ + patch("torch.npu.current_device", return_value=0): + yield + + +@pytest.fixture +def mock_forward_context(): + forward_context = Mock(in_profile_run=False, with_prefill=False) + with patch("vllm_ascend.models.deepseek_v2.get_forward_context", + return_value=forward_context): + yield diff --git a/tests/ut/models/test_deepseek_mtp.py b/tests/ut/models/test_deepseek_mtp.py index 61fdf98..1dc7c9c 100644 --- a/tests/ut/models/test_deepseek_mtp.py +++ b/tests/ut/models/test_deepseek_mtp.py @@ 
-13,10 +13,13 @@ from vllm_ascend.models.deepseek_mtp import ( class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase): @pytest.fixture - def setup_mtp_layer(self, mocker: MockerFixture): + def setup_mtp_layer(self, mocker: MockerFixture, vllm_config: VllmConfig, + mock_distributed): config = PretrainedConfig(vocab_size=1000, hidden_size=768, rms_norm_eps=1e-5) + mocker.patch("vllm_ascend.models.deepseek_mtp.get_current_vllm_config", + return_value=vllm_config) mocker.patch( "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", return_value=None) @@ -29,15 +32,15 @@ class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase): "vllm_ascend.models.deepseek_mtp.CustomDeepSeekShareHead.__init__", return_value=None) mocker_deepseek_v2_decode_layer = mocker.patch( - "vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__", + "vllm.model_executor.models.deepseek_v2.DeepseekV2DecoderLayer.__init__", return_value=None) mocker.patch( "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", return_value=None) - mocker.patch("vllm_ascend.utils.get_ascend_config", + mocker.patch("vllm_ascend.models.deepseek_v2.get_ascend_config", return_value=mocker.Mock()) - mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None) + mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "0", None) mocker_deepseek_v2_decode_layer.assert_called_once() return mtp_layer @@ -165,8 +168,6 @@ class TestCustomDeepSeekMTP(PytestBase): mocker.patch( "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__", return_value=None) - mocker.patch("vllm.model_executor.layers.sampler.get_sampler", - return_value=None) mocker.patch( "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", return_value=None) diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py index df14a2a..693aea5 100644 --- a/tests/ut/models/test_deepseek_v2.py +++ 
b/tests/ut/models/test_deepseek_v2.py @@ -12,169 +12,23 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # -from types import SimpleNamespace from unittest.mock import Mock, patch import pytest import torch -from transformers import PretrainedConfig from vllm.config import CacheConfig -from vllm.distributed.parallel_state import GroupCoordinator +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm_ascend.models.deepseek_v2 import ( - CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention, - CustomDeepseekV2MLP, CustomDeepseekV2MoE, - CustomDeepseekV2RowParallelLinear, - CustomDeepseekV2RowParallelLinearReplaceAllreduce, - CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead) +from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention, + CustomDeepseekV2RowParallelLinear) -@pytest.fixture -def base_config(): - config = PretrainedConfig( - hidden_size=128, - num_attention_heads=8, - num_hidden_layers=2, - intermediate_size=256, - hidden_act="silu", - rms_norm_eps=1e-6, - rope_theta=10000.0, - max_position_embeddings=2048, - n_routed_experts=4, - n_shared_experts=1, - moe_intermediate_size=256, - num_experts_per_tok=2, - routed_scaling_factor=1.0, - first_k_dense_replace=0, - moe_layer_freq=1, - kv_lora_rank=16, - qk_nope_head_dim=16, - qk_rope_head_dim=16, - v_head_dim=32, - topk_method="noaux_tc", - scoring_func="softmax", - norm_topk_prob=True, - n_group=1, - topk_group=1, - vocab_size=10000, - ) - return config - - -@pytest.fixture -def vllm_config(base_config): - model_config = SimpleNamespace( - hf_config=base_config, - tensor_parallel_size=1, - dtype=torch.float32, - use_mla=False, - quant_config=None, - max_model_len=2048, - ) - - cache_config = CacheConfig() - vllm_config = Mock() - vllm_config.model_config = model_config - vllm_config.cache_config = cache_config - 
vllm_config.quant_config = None - return vllm_config - - -@pytest.fixture -def mock_distributed(): - tp_group = Mock(spec=GroupCoordinator) - tp_group.rank_in_group = 0 - tp_group.world_size = 1 - tp_group.device_group = Mock() - - dp_group = Mock(spec=GroupCoordinator) - dp_group.rank_in_group = 0 - dp_group.world_size = 1 - - ep_group = Mock(spec=GroupCoordinator) - ep_group.rank_in_group = 0 - ep_group.world_size = 1 - - pp_group = Mock(spec=GroupCoordinator) - pp_group.rank_in_group = 0 - pp_group.world_size = 1 - - mock_vllm_config = Mock() - mock_vllm_config.scheduler_config = Mock(max_num_seqs=256) - mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None) - - with patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \ - patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \ - patch("vllm_ascend.models.deepseek_v2.get_tp_group", return_value=tp_group), \ - patch("vllm_ascend.models.deepseek_v2.get_ep_group", return_value=ep_group), \ - patch("vllm_ascend.models.deepseek_v2.get_dp_group", return_value=dp_group), \ - patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \ - patch("vllm_ascend.models.deepseek_v2.get_pp_group", - return_value=Mock(is_first_rank=False, is_last_rank=False)), \ - patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ - patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, - _PP=pp_group), \ - patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \ - patch("torch.npu.current_device", return_value=0): - yield - - -@pytest.fixture -def mock_forward_context(): - forward_context = Mock(in_profile_run=False, with_prefill=False) - with patch("vllm_ascend.models.deepseek_v2.get_forward_context", - return_value=forward_context): - yield - - -def test_custom_deepseek_v2_silu_and_mul(): - torch.set_default_device("cpu") - - 
silu = CustomDeepseekV2SiluAndMul() - assert silu.weight_scale is None - - x = torch.randn(2, 4) - output = silu.forward_oot(x) - assert output.shape == (2, 2) - - weight_scale = Mock(return_value=torch.tensor(0.1)) - silu = CustomDeepseekV2SiluAndMul(weight_scale=weight_scale) - quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32) - dynamic_scale = torch.randn(2, 1) - with patch("torch_npu.npu_dequant_swiglu_quant", - return_value=torch.randn(2, 4)): - output = silu.forward_oot((quant_x, dynamic_scale)) - assert output.shape == (2, 4) - - -def test_custom_deepseek_v2_merged_replicated_linear(mock_distributed): - linear = CustomDeepseekV2MergedReplicatedLinear(input_size=128, - output_sizes=[64, 64], - bias=False, - quant_config=None) - assert linear.output_sizes == [64, 64] - - param = Mock() - param.data = torch.zeros(128, 128) - param.output_dim = 1 - param.is_gguf_weight = False - param.is_gguf_weight_type = False - loaded_weight = torch.randn(128, 64) - linear.weight_loader(param, loaded_weight, loaded_shard_id=0) - - with pytest.raises(AssertionError): - linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0) - - -@pytest.mark.parametrize("cls", [ - CustomDeepseekV2RowParallelLinearReplaceAllreduce, - CustomDeepseekV2RowParallelLinear -]) +@pytest.mark.parametrize("cls", [CustomDeepseekV2RowParallelLinear]) def test_row_parallel_linear(cls, mock_distributed): linear = cls(input_size=128, output_size=64, bias=False, quant_config=None) linear.quant_method = Mock() linear.quant_method.apply.return_value = torch.randn(2, 4, 64) - input_ = torch.randn(2, 4, 128) with patch("vllm_ascend.models.deepseek_v2.split_tensor_along_last_dim", return_value=[torch.randn(2, 4, 64)]): @@ -187,52 +41,10 @@ def test_row_parallel_linear(cls, mock_distributed): assert output[0].shape == (2, 4, 64) -def test_custom_deepseek_v2_mlp(mock_distributed, base_config): - mlp = CustomDeepseekV2MLP(hidden_size=128, - intermediate_size=256, - hidden_act="silu", - 
quant_config=None) - assert isinstance(mlp.act_fn, CustomDeepseekV2SiluAndMul) - - x = torch.randn(2, 4, 128) - output = mlp(x) - assert output.shape == (2, 4, 128) - - with patch("vllm_ascend.models.deepseek_v2.QuantizationConfig" - ) as mock_quant_config: - mock_quant_config.name = "w8a8dynamic" - with pytest.raises(NotImplementedError): - CustomDeepseekV2MLP(hidden_size=128, - intermediate_size=256, - hidden_act="silu", - quant_config=mock_quant_config, - force_replicate=False) - with pytest.raises(ValueError): - CustomDeepseekV2MLP(hidden_size=128, - intermediate_size=256, - hidden_act="relu", - quant_config=None) - - -def test_custom_deepseek_v2_moe(mock_distributed, base_config, - mock_forward_context): - base_config.n_shared_experts = 1 - moe = CustomDeepseekV2MoE(config=base_config, - quant_config=None, - prefix="mlp") - assert moe.top_k == 2 - - x = torch.randn(2, 4, 128) - attn_metadata = Mock(num_prefills=1) - with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__", - return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))): - output = moe(x, attn_metadata) - assert output.shape == (2, 4, 128) - - +@patch("torch.ops.vllm.mla_forward") @patch("torch_npu.npu_rms_norm") -def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, - base_config): +def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward, + mock_distributed, base_config): mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) attn = CustomDeepseekV2MLAAttention(config=base_config, @@ -253,8 +65,8 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, with patch.object(attn.mla_attn, "__call__", return_value=torch.randn(2, 4, 128)): - with pytest.raises(AssertionError): - attn(positions, x) + attn(positions, x) + mock_mla_forward.assert_called_once() attn = CustomDeepseekV2MLAAttention(config=base_config, hidden_size=128, diff --git a/tests/ut/models/test_qwen2_5_vl.py b/tests/ut/models/test_qwen2_5_vl.py 
index 15367eb..06fb07d 100644 --- a/tests/ut/models/test_qwen2_5_vl.py +++ b/tests/ut/models/test_qwen2_5_vl.py @@ -286,6 +286,22 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase): "vllm_ascend.models.qwen2_5_vl.parallel_state.get_tensor_model_parallel_world_size", return_value=2, ) + mocker.patch( + "vllm_ascend.ops.linear.divide", + return_value=2, + ) + + mock_group = mocker.MagicMock() + mock_group.rank_in_group = 0 + mock_group.world_size = 2 + mocker.patch( + "vllm_ascend.ops.linear_op.get_tp_group", + return_value=mock_group, + ) + mocker.patch( + "vllm.distributed.parallel_state.get_tp_group", + return_value=mock_group, + ) vision_transformer = AscendQwen2_5_VisionTransformer( vision_config, @@ -341,6 +357,46 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase): cos_new, _ = vision_transformer.cal_cos_sin(self.input_data) assert cos_new.shape == (1, 32, 1, 2) + def test_pad_qkv_bias(self, mocker: MockerFixture): + attention = self.init_vision_transformer(mocker) + mocker.patch("torch.nn.Module.__setattr__") + mocker.patch("torch.nn.Module.__getattr__") + mocker.patch("torch.nn.Module.__delattr__") + res = attention.pad_qkv_bias(torch.rand((300))) + assert res.shape[0] == 384 + + def test_pad_qkv_weight(self, mocker: MockerFixture): + attention = self.init_vision_transformer(mocker) + mocker.patch("torch.nn.Module.__setattr__") + mocker.patch("torch.nn.Module.__getattr__") + mocker.patch("torch.nn.Module.__delattr__") + res = attention.pad_qkv_weight(torch.rand((300, 300))) + assert res.shape == (384, 300) + + def test_pad_proj_weight(self, mocker: MockerFixture): + attention = self.init_vision_transformer(mocker) + mocker.patch("torch.nn.Module.__setattr__") + mocker.patch("torch.nn.Module.__getattr__") + mocker.patch("torch.nn.Module.__delattr__") + res = attention.pad_proj_weight(torch.rand((300, 300))) + assert res.shape == (300, 384) + + def test_pad_qkv_weight_scale_offset(self, mocker: MockerFixture): + attention = 
self.init_vision_transformer(mocker) + mocker.patch("torch.nn.Module.__setattr__") + mocker.patch("torch.nn.Module.__getattr__") + mocker.patch("torch.nn.Module.__delattr__") + res = attention.pad_qkv_weight_scale_offset(torch.rand((300, 1))) + assert res.shape == (384, 1) + + def test_pad_qkv_deq_scale_quant_bias(self, mocker: MockerFixture): + attention = self.init_vision_transformer(mocker) + mocker.patch("torch.nn.Module.__setattr__") + mocker.patch("torch.nn.Module.__getattr__") + mocker.patch("torch.nn.Module.__delattr__") + res = attention.pad_qkv_deq_scale_quant_bias(torch.rand((300))) + assert res.shape[0] == 384 + def test_forward(self, mocker: MockerFixture): vision_transformer = self.init_vision_transformer(mocker) mocker.patch("torch.nn.Module.__setattr__") diff --git a/tests/ut/models/test_qwen3_moe.py b/tests/ut/models/test_qwen3_moe.py index e882fe2..858b106 100644 --- a/tests/ut/models/test_qwen3_moe.py +++ b/tests/ut/models/test_qwen3_moe.py @@ -15,41 +15,11 @@ import math import unittest -import pytest import torch -from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM -from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention -class TestCustomQwen3MoeForCausalLM: - - def test_class_inheritance(self): - assert issubclass(CustomQwen3MoeForCausalLM, Qwen3MoeForCausalLM) - - @pytest.mark.parametrize("key, expected", [ - ("qkv_proj", ["q_proj", "k_proj", "v_proj"]), - ("gate_up_proj", ["gate_proj", "up_proj"]), - ("experts", - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]), - ]) - def test_packed_modules_mapping(self, key, expected): - assert CustomQwen3MoeForCausalLM.packed_modules_mapping[ - key] == expected - - def test_packed_modules_mapping_structure(self): - expected_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": [ - "experts.0.gate_proj", "experts.0.up_proj", 
- "experts.0.down_proj" - ] - } - assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping - - class DummyRMSNorm: def __init__(self, dim: int, eps: float = 1e-6): diff --git a/tests/ut/ops/test_activation.py b/tests/ut/ops/test_activation.py index b90ccff..76bc55d 100644 --- a/tests/ut/ops/test_activation.py +++ b/tests/ut/ops/test_activation.py @@ -38,7 +38,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor): @pytest.mark.parametrize("is_310p_return", [True, False]) @patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1) -def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor): +@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) +@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj", + side_effect=lambda x: None) +def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj, + mock_maybe_wait_prefetch_done, mock_swiglu, + is_310p_return, dummy_tensor): with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return): layer = SiluAndMul() @@ -49,9 +54,15 @@ def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor): else: expected_arg = dummy_tensor + # assert mock_maybe_prefetch_mlp_down_proj.call_count == 1 + mock_maybe_prefetch_mlp_down_proj.assert_called_once() + # assert mock_swiglu.call_count == 1 mock_swiglu.assert_called_once() + # assert mock_maybe_wait_prefetch_done.call_count == 1 + mock_maybe_wait_prefetch_done.assert_called_once() + actual_arg = mock_swiglu.call_args[0][0] assert torch.allclose( actual_arg, diff --git a/tests/ut/ops/test_comm_utils.py b/tests/ut/ops/test_comm_utils.py new file mode 100644 index 0000000..5b4071c --- /dev/null +++ b/tests/ut/ops/test_comm_utils.py @@ -0,0 +1,98 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +import pytest +import torch +from pytest_mock import MockerFixture + +from tests.ut.base import PytestBase +from vllm_ascend.ops.moe.comm_utils import ( + _gather_along_first_dim, async_all_to_all, + gather_from_sequence_parallel_region) + + +class TestDistributedCommunication(PytestBase): + + @pytest.fixture(autouse=True) + def context(self, mocker: MockerFixture): + mocker.patch("torch.npu.current_device", return_value="cpu") + mocker.patch("torch.distributed.get_world_size", return_value=4) + + mocker.patch("torch.distributed.get_rank", return_value=0) + + @pytest.mark.parametrize( + "input_tensor, output_split_sizes, input_split_sizes", + [(torch.randn(8, 16), [2, 2, 2, 2], [2, 2, 2, 2]), + (torch.randn(16, 32), None, None)]) + def test_async_all_to_all(self, input_tensor, output_split_sizes, + input_split_sizes, mocker: MockerFixture): + """Test async_all_to_all""" + mock_group = mocker.MagicMock() + mocker.patch("torch.distributed.all_to_all_single", + return_value=mocker.MagicMock()) + + _, a2a_out, handle = async_all_to_all(input_tensor, output_split_sizes, + input_split_sizes, mock_group) + + # Check if the output tensor is created properly + if output_split_sizes is None: + assert a2a_out.shape == input_tensor.shape + else: + total_output_size = sum(output_split_sizes) + expected_shape = [total_output_size] + list( + input_tensor.size())[1:] + assert a2a_out.shape == torch.Size(expected_shape) + + # Ensure handle is returned from async operation + assert handle is not None + assert 
isinstance(handle, mocker.MagicMock) + + @pytest.mark.parametrize("world_size, test_tensor, expected", + [(1, torch.randn(8, 16), (8, 16)), + (4, torch.randn(8, 16), (32, 16))]) + def test_gather_along_first_dim(self, test_tensor, expected, world_size, + mocker: MockerFixture): + """Test _gather_along_first_dim""" + mocker.patch("torch.distributed.get_world_size", + return_value=world_size) + + result = _gather_along_first_dim(test_tensor, mocker.MagicMock()) + + assert result.shape == expected + + @pytest.mark.parametrize("input_tensor, output_split_sizes", + [(torch.randn(8, 16), None), + (torch.randn(8, 16), [2, 2, 2, 2])]) + def test_gather_from_sequence_parallel_region(self, input_tensor, + output_split_sizes, + mocker: MockerFixture): + """Test gather_from_sequence_parallel_region""" + mock_group = mocker.MagicMock() + + result = gather_from_sequence_parallel_region(input_tensor, mock_group, + output_split_sizes) + + # If output_split_sizes is not provided, result should have expanded first dimension by world size + if output_split_sizes is None: + expected_shape = [input_tensor.shape[0] * 4] + list( + input_tensor.shape[1:]) + assert result.shape == torch.Size(expected_shape) + else: + # If output_split_sizes is provided, result shape is dictated by sum of output_split_sizes + expected_shape = [sum(output_split_sizes)] + list( + input_tensor.shape[1:]) + assert result.shape == torch.Size(expected_shape) diff --git a/tests/ut/ops/test_common_fused_moe.py b/tests/ut/ops/test_common_fused_moe.py index 409a301..6153a4e 100644 --- a/tests/ut/ops/test_common_fused_moe.py +++ b/tests/ut/ops/test_common_fused_moe.py @@ -17,53 +17,40 @@ from unittest.mock import patch import torch from tests.ut.base import TestBase -from vllm_ascend.ops.common_fused_moe import fused_experts_moge +from vllm_ascend.ops.common_fused_moe import AscendFusedMoE -class TestFusedExpertsMoGE(TestBase): +class TestLoadWeight(TestBase): - def test_fused_experts_moge(self): - with 
patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \ - patch('torch_npu.npu_swiglu') as mock_swiglu, \ - patch('vllm_ascend.utils.is_310p') as mock_is_310p: + def test_load_w13_transpose(self): + with patch.object(AscendFusedMoE, "__init__", + lambda self, *args, **kwargs: None): + moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8) - mock_is_310p.return_value = False + expert_data = torch.randn(128, 8) + loaded_weight = torch.randn(128, 4) + moe._load_w13(expert_data, 1, "w1", loaded_weight, 0) - mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [ - torch.randn(x[0].shape[0], weight[0].shape[1]) - ] + expert_data = torch.randn(8, 128) + loaded_weight = torch.randn(128, 4) + moe._load_w13(expert_data, 1, "w1", loaded_weight, 0) - mock_swiglu.side_effect = lambda x: x + expert_data = torch.randn(128, 8) + loaded_weight = torch.randn(128, 4) + moe._load_w13(expert_data, 1, "w3", loaded_weight, 0) - hidden_states = torch.randn(4, 128) - w1 = torch.randn(4, 256, 128) - w2 = torch.randn(4, 128, 128) - topk_weights = torch.rand(4, 1) - topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long) - top_k = 1 - global_num_experts = 4 + expert_data = torch.randn(8, 128) + loaded_weight = torch.randn(128, 4) + moe._load_w13(expert_data, 1, "w3", loaded_weight, 0) - moe_parallel_config = type( - 'MockConfig', (), { - 'ep_size': 1, - 'tp_size': 1, - 'dp_size': 1, - 'tp_rank': 0, - 'dp_rank': 0, - 'ep_rank': 0, - 'use_ep': True - })() + def test_load_w2_transpose(self): + with patch.object(AscendFusedMoE, "__init__", + lambda self, *args, **kwargs: None): + moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8) + expert_data = torch.randn(128, 4) + loaded_weight = torch.randn(128, 8) + moe._load_w2(expert_data, 1, loaded_weight, 0) - output = fused_experts_moge( - hidden_states=hidden_states, - w1=w1, - w2=w2, - moe_parallel_config=moe_parallel_config, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - 
global_num_experts=global_num_experts, - apply_router_weight_on_input=True, - ) - - self.assertEqual(output.shape, (4, 128)) + expert_data = torch.randn(4, 128) + loaded_weight = torch.randn(128, 8) + moe._load_w2(expert_data, 1, loaded_weight, 0) diff --git a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py new file mode 100644 index 0000000..ce7970c --- /dev/null +++ b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py @@ -0,0 +1,289 @@ +import unittest +from unittest.mock import MagicMock, patch + +import torch +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + +from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( + FusedMoEPrepareAndFinalizeWithAll2All, + FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, + FusedMoEPrepareAndFinalizeWithNaiveMulticast) +from vllm_ascend.utils import vllm_version_is + + +class TestFusedMoEPrepareAndFinalize(unittest.TestCase): + + def setUp(self): + # Mock FusedMoEConfig + self.moe_config = MagicMock(spec=FusedMoEConfig) + self.moe_config.tp_group = MagicMock() + self.moe_config.tp_group.device_group = MagicMock() + self.moe_config.dp_size = 1 + self.moe_config.tp_size = 1 + self.moe_config.ep_size = 1 + self.moe_config.dp_group = MagicMock() + + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", + return_value=1) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", + return_value=0) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" + ) + def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank, + mock_tp_size): + mock_context = MagicMock() + mock_context.mc2_mask = torch.tensor([1, 0, 1]) + mock_context.padded_num_tokens = 4 + mock_get_forward_context.return_value = mock_context + + layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config) + + hidden_states = torch.randn(3, 
8) + router_logits = torch.randn(3, 2) + + h_out, r_out, mask = layer.prepare(hidden_states, router_logits) + + # Check padding and split + self.assertEqual(h_out.shape[0], 4) + self.assertEqual(r_out.shape[0], 4) + self.assertEqual(mask.tolist(), [1, 0, 1]) + + # Finalize + result = layer.finalize(h_out, reduce_results=False) + self.assertEqual(result.shape[0], 3) + + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", + return_value=2) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", + return_value=0) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" + ) + @patch("torch.distributed.all_gather") + def test_mc2_tp_split_allgather(self, mock_all_gather, + mock_get_forward_context, mock_tp_rank, + mock_tp_size): + mock_context = MagicMock() + mock_context.mc2_mask = torch.tensor([1, 0, 1, 0]) + mock_context.padded_num_tokens = 4 + mock_get_forward_context.return_value = mock_context + + layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config) + hidden_states = torch.randn(4, 8) + router_logits = torch.randn(4, 2) + + h_out, r_out, mask = layer.prepare(hidden_states, + router_logits, + enable_shared_expert_dp=False, + replace_allreduce=False) + + # With TP=2, should split into 2 parts + self.assertEqual(h_out.shape[0], 2) + + # Mock all_gather behavior + def mock_all_gather_func(tensor_list, tensor, group=None): + tensor_list[0] = tensor + tensor_list[1] = tensor.clone() + + mock_all_gather.side_effect = mock_all_gather_func + + layer.split_hidden_states = [ + torch.zeros_like(h_out), + torch.zeros_like(h_out) + ] + final_result = layer.finalize(h_out, reduce_results=False) + + # Should concat back to original size + self.assertEqual(final_result.shape[0], 4) + + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", + return_value=1) + @patch( + 
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", + return_value=0) + def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size): + layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + + h_out, r_out, _ = layer.prepare(hidden_states, router_logits) + + # Pad to tp_size=1, so no change + self.assertEqual(h_out.shape[0], 3) + + result = layer.finalize(h_out, reduce_results=False) + self.assertEqual(result.shape[0], 3) + + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", + return_value=2) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", + return_value=0) + @patch("torch.distributed.all_gather") + def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank, + mock_tp_size): + layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) + hidden_states = torch.randn(2, 8) + router_logits = torch.randn(2, 2) + + h_out, r_out, _ = layer.prepare(hidden_states, + router_logits, + enable_shared_expert_dp=False, + replace_allreduce=False) + + # Split due to TP=2 + self.assertEqual(h_out.shape[0], 1) + + # Mock all_gather + def mock_all_gather_func(tensor_list, tensor, group=None): + tensor_list[0] = tensor + tensor_list[1] = tensor.clone() + + mock_all_gather.side_effect = mock_all_gather_func + + layer.split_hidden_states = [ + torch.zeros_like(h_out), + torch.zeros_like(h_out) + ] + final_result = layer.finalize(h_out, reduce_results=False) + + # Should concat back + self.assertEqual(final_result.shape[0], 2) + + @patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group") + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce" + ) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" + ) + def test_allgather_prepare_finalize(self, 
mock_get_forward_context, + mock_tp_all_reduce, mock_get_dp_group): + # Mock forward context + mock_context = MagicMock() + mock_context.max_tokens_across_dp = 6 + mock_get_forward_context.return_value = mock_context + + # Create a proper mock for DP group with working all_gather + mock_dp_group = MagicMock() + + def mock_all_gather_func(tensor, dim): + # Simulate DP=2: repeat the tensor along the specified dimension + return torch.cat([tensor, tensor], dim=dim) + + mock_dp_group.all_gather = mock_all_gather_func + mock_get_dp_group.return_value = mock_dp_group + + self.moe_config.dp_size = 2 + self.moe_config.tp_size = 1 + self.moe_config.ep_size = 1 + self.moe_config.dp_group = mock_dp_group + + layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config) + + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + + # Mock the gate function for rm_router_logits=False case + mock_gate = MagicMock() + mock_gate.return_value = (router_logits.repeat(2, 1), None) + + h_out, r_out, _ = layer.prepare(hidden_states, + router_logits, + rm_router_logits=False, + gate=mock_gate) + + # After all-gather with DP=2, should double the batch size + self.assertEqual(h_out.shape[0], 12) + self.assertEqual(r_out.shape[0], 12) + + # Finalize with reduce_scatter + def mock_reduce_scatter_func(tensor, dim): + # Simulate reduce_scatter: take first half + return tensor[:3] + + mock_dp_group.reduce_scatter = mock_reduce_scatter_func + result = layer.finalize(h_out, reduce_results=False) + + self.assertEqual(result.shape[0], 3) + + # Test with TP all-reduce + mock_tp_all_reduce.return_value = result + result_with_tp = layer.finalize(h_out, reduce_results=True) + self.assertEqual(result_with_tp.shape[0], 3) + + @patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group") + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce" + ) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" + ) 
+ def test_naive_multicast_prepare_finalize(self, mock_get_forward_context, + mock_tp_all_reduce, + mock_get_dp_group): + # Mock forward context with DP metadata + mock_context = MagicMock() + if vllm_version_is("0.10.2"): + mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor( + [2, 5, 7]) + else: + mock_context.dp_metadata.cu_tokens_across_sp.return_value = torch.tensor( + [2, 5, 7]) + mock_get_forward_context.return_value = mock_context + + # Setup DP group mock + mock_dp_group = MagicMock() + mock_dp_group.broadcast = MagicMock() + mock_dp_group.all_reduce = MagicMock() + mock_get_dp_group.return_value = mock_dp_group + + # Mock all_reduce to just return input (simulate sum) + def mock_all_reduce(tensor): + return tensor * 2 + + mock_dp_group.all_reduce.side_effect = mock_all_reduce + + # Setup config + self.moe_config.dp_size = 3 + self.moe_config.dp_rank = 1 + self.moe_config.tp_size = 1 + self.moe_config.ep_size = 1 + + layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config) + + # Local inputs + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + + # Mock gate for router logits recomputation + mock_gate = MagicMock() + mock_gate.return_value = (torch.randn(7, 2), None) + + # Run prepare + h_out, r_out, _ = layer.prepare(hidden_states, + router_logits, + rm_router_logits=False, + gate=mock_gate) + + # Should be global tensor: [7, 8] and [7, 2] + self.assertEqual(h_out.shape, (7, 8)) + self.assertEqual(r_out.shape, (7, 2)) + + # Run finalize + result = layer.finalize(h_out, reduce_results=False) + + # Should slice back to local: [3, 8] + self.assertEqual(result.shape, (3, 8)) + + # Test with reduce_results=True and TP/EP > 1 + mock_tp_all_reduce.return_value = result + result_with_tp = layer.finalize(h_out, reduce_results=True) + self.assertEqual(result_with_tp.shape, (3, 8)) diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 6a51d1d..a5bdfe2 100644 --- 
a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -22,15 +22,13 @@ import torch_npu from pytest_mock import MockerFixture from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase -import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module from tests.ut.base import TestBase -from vllm_ascend.ascend_forward_context import (FusedMoEState, - _get_fused_moe_state) +from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.fused_moe import (AscendFusedMoE, AscendUnquantizedFusedMoEMethod) -from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp -from vllm_ascend.utils import AscendSocVersion, adapt_patch +from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp +from vllm_ascend.utils import AscendSocVersion, adapt_patch, vllm_version_is adapt_patch(True) @@ -58,122 +56,94 @@ def mock_npu_format_cast(weight_data, format): return weight_data +@pytest.fixture(autouse=True) +def setup_vllm_config_mock(mocker: MockerFixture): + mock_hf_config = MagicMock() + mock_hf_config.model_type = "llama" + + mock_model_config = MagicMock() + mock_model_config.hf_config = mock_hf_config + + mock_vllm_config = MagicMock() + mock_vllm_config.model_config = mock_model_config + mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2) + mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4) + mock_vllm_config.model_config.max_model_len = 2048 + + mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', + return_value=mock_vllm_config) + mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config', + return_value=mock_vllm_config) + + @pytest.fixture def mock_dist_env(mocker: MockerFixture): - mock_setup_token_dispatchers = MagicMock() - mock_token_dispatcher_with_allgather = MagicMock() - 
mock_token_dispatcher_with_all2allv = MagicMock() - mock_token_dispatcher_with_mc2 = MagicMock() + mock_moe_comm_method = MagicMock() - mock_dispatch_result_allgather = { - "hidden_states": torch.randn(16, 2), - "group_list": torch.tensor([8, 16], dtype=torch.int64), - "group_list_type": 0, - } - mock_combine_result_allgather = torch.randn(16, 2) + def mock_prepare(hidden_states, router_logits, **kwargs): + return hidden_states, router_logits - mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather - mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather + mock_moe_comm_method.prepare.side_effect = mock_prepare - mock_dispatch_result_all2allv = { - "hidden_states": torch.randn(16, 2), - "group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64), - "group_list_type": 1, - "dynamic_scale": None, - } - mock_combine_result_all2allv = torch.randn(16, 2) - mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv - mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv + mock_fused_experts_result = torch.randn(16, 2) + mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result - mock_dispatch_result_mc2 = { - "hidden_states": torch.randn(16, 2), - "group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64), - "group_list_type": 1, - "dynamic_scale": None, - "assist_info_for_combine": torch.randn(16, 2), - "ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32), - } - mock_combine_result_mc2 = torch.randn(16, 2) - mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2 - mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2 + def mock_finalize(hidden_states, **kwargs): + return hidden_states - captured_dispatchers = {} + mock_moe_comm_method.finalize.side_effect = mock_finalize - def capture_register(dispatcher_instance): - 
key = dispatcher_instance.__class__.__name__ - captured_dispatchers[key] = dispatcher_instance - if key == 'TokenDispatcherWithAllGather': - captured_dispatchers[key] = mock_token_dispatcher_with_allgather - elif key == 'TokenDispatcherWithAll2AllV': - captured_dispatchers[key] = mock_token_dispatcher_with_all2allv - elif key == 'TokenDispatcherWithMC2': - captured_dispatchers[key] = mock_token_dispatcher_with_mc2 - - mock_register_token_dispatcher_patcher = patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher', - side_effect=capture_register) - - mock_get_token_dispatcher_patcher = patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher', - side_effect=lambda name: captured_dispatchers.get(name)) - - default_mock_token_dispatcher = mock_token_dispatcher_with_allgather - - mock_forward_context_obj = MagicMock( - fused_moe_state=FusedMoEState.AllGather, - token_dispatcher=default_mock_token_dispatcher, - max_tokens_across_dp=10, - dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]), - mc2_mask=torch.zeros(16, dtype=torch.bool), - padded_num_tokens=16, - with_quant=False) + if vllm_version_is("0.10.2"): + dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10]) + else: + dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5]) + mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method, + moe_comm_type=MoECommType.MC2, + max_tokens_across_dp=10, + dp_metadata=dp_metadata, + mc2_mask=torch.zeros( + 16, dtype=torch.bool), + padded_num_tokens=16, + with_quant=False) with patch('torch.distributed.get_rank', return_value=0), \ patch('torch.distributed.get_world_size', return_value=4), \ patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ + patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ 
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('torch.distributed.all_gather'), \ - patch('torch.distributed.all_to_all_single'), \ - patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \ - patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \ patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_ascend_config', return_value=MagicMock( - torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False), + torchair_graph_config=MagicMock(enabled=False), + enable_multistream_moe=False, expert_map_path=None )), \ patch('vllm_ascend.ops.fused_moe.determine_expert_map', return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ patch('vllm_ascend.ops.fused_moe.get_forward_context', return_value=mock_forward_context_obj), \ - patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', - return_value=MagicMock( - parallel_config=MagicMock(tensor_parallel_size=2), - scheduler_config=MagicMock(max_num_seqs=4), - model_config=MagicMock(max_model_len=2048) - )), \ + patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context', + return_value=mock_forward_context_obj), \ patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \ - patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \ - patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context', - return_value=mock_forward_context_obj): + patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context', + return_value=mock_forward_context_obj), \ + 
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher', + return_value=None), \ + patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher', + return_value=None), \ + patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher', + return_value=None): yield { 'mock_forward_context_obj': mock_forward_context_obj, - 'mock_token_dispatcher_with_allgather': - mock_token_dispatcher_with_allgather, - 'mock_token_dispatcher_with_all2allv': - mock_token_dispatcher_with_all2allv, - 'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2, + 'mock_moe_comm_method': mock_moe_comm_method, } - mock_register_token_dispatcher_patcher.stop() - mock_get_token_dispatcher_patcher.stop() - @pytest.fixture def mock_moe_env(mocker: MockerFixture): @@ -235,6 +205,8 @@ def default_moe_config(): def moe_method(mock_dist_env): moe = MagicMock() moe.moe_parallel_config.return_value = MagicMock(ep_size=4) + moe.moe_parallel_config.use_ep = False + moe.moe_parallel_config.dp_size = 1 return AscendUnquantizedFusedMoEMethod(moe) @@ -280,6 +252,9 @@ class MockFusedMoEMethod(FusedMoEMethodBase): expert_weights: torch.Tensor) -> torch.Tensor: pass + def get_fused_moe_quant_config(self, layer: torch.nn.Module): + pass + class TestAscendFusedMoe: @@ -339,9 +314,7 @@ class TestAscendFusedMoe: moe.moe_parallel_config.ep_size = 1 moe.quant_method = MockQuantMethod(shared_experts, num_tokens) - forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens, - dtype=torch.bool), - padded_num_tokens=num_tokens) + forward_context = mock_dist_env['mock_forward_context_obj'] with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context): output = moe.forward(inputs, @@ -395,25 +368,10 @@ class TestAscendUnquantizedFusedMoEMethod: [[256, 4], [128, 1], [128, 1], [128, 4]]) def test_apply_without_expert_map(self, moe_method, mock_dist_env, mock_moe_env, others_param): - global_num_experts, ep_size = others_param 
is_prefill = False - is_deepseek_v3_r1 = global_num_experts == 256 - if ep_size == 1: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_allgather'] - elif ep_size < 16: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_all2allv'] - else: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_mc2'] - - forward_context = MagicMock(fused_moe_state=_get_fused_moe_state( - ep_size, is_prefill, is_deepseek_v3_r1), - with_quant=False, - token_dispatcher=selected_token_dispatcher) + forward_context = mock_dist_env['mock_forward_context_obj'] with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context): @@ -439,35 +397,22 @@ class TestAscendUnquantizedFusedMoEMethod: global_num_experts=global_num_experts, is_prefill=is_prefill) - expected_shape = (16, 2) + mock_moe_comm_method = mock_dist_env['mock_moe_comm_method'] + mock_moe_comm_method.fused_experts.assert_called_once() + expected_shape = (16, 2) assert result.shape == expected_shape @pytest.mark.parametrize("others_param", [16, 1, 4]) def test_apply_with_expert_map(self, moe_method, mock_dist_env, mock_moe_env, others_param): - ep_size = others_param is_prefill = False - if ep_size == 1: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_allgather'] - elif ep_size < 16: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_all2allv'] - else: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_mc2'] - - forward_context = MagicMock(fused_moe_state=_get_fused_moe_state( - ep_size, is_prefill, True), - with_quant=False, - token_dispatcher=selected_token_dispatcher) + forward_context = mock_dist_env['mock_forward_context_obj'] with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \ patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3): - expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) 
moe_method.ep_size = ep_size x = torch.randn(8, 2, 2) @@ -494,8 +439,10 @@ class TestAscendUnquantizedFusedMoEMethod: expert_map=expert_map, is_prefill=is_prefill) - expected_shape = (16, 2) + mock_moe_comm_method = mock_dist_env['mock_moe_comm_method'] + mock_moe_comm_method.fused_experts.assert_called_once() + expected_shape = (16, 2) assert result.shape == expected_shape @@ -524,10 +471,47 @@ class TestExpertsSelector: assert topk_ids.shape == (8, 2) +class TestCumsumGroupList(TestBase): + + def setUp(self): + self.active_num = 8 + self.expert_num = 128 + self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64) + self.experts[:self.active_num] = 1 + self.experts = self.experts[torch.randperm(self.expert_num)] + self.group_list = self.experts.cumsum(dim=0) + + def test_cumsum_group_list_with_type_0(self): + group_list = self.experts.cumsum(dim=0) + group_list_type = 0 + result = cumsum_group_list(group_list, group_list_type) + self.assertTrue(torch.equal(result, self.group_list)) + + def test_cumsum_group_list_with_type_1(self): + group_list = self.experts + group_list_type = 1 + result = cumsum_group_list(group_list, group_list_type) + self.assertTrue(torch.equal(result, self.group_list)) + + def test_cumsum_group_list_with_type_2(self): + tokens = torch.arange(self.expert_num, dtype=torch.int64) + group_list = torch.cat([ + tokens.reshape(self.expert_num, 1), + self.experts.reshape(self.expert_num, 1) + ], + dim=1) + group_list_type = 2 + result = cumsum_group_list(group_list, + group_list_type, + active_num=self.active_num, + expert_num=self.expert_num) + self.assertTrue(torch.equal(result, self.group_list)) + + class TestUnifiedApplyMLP(TestBase): - @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context') - @patch('vllm_ascend.ops.layers.moe_mlp.is_310p') + @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context') + @patch('vllm_ascend.ops.moe.moe_mlp.is_310p') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_dynamic_quant') 
@patch('torch_npu.npu_dequant_swiglu_quant') @@ -538,7 +522,7 @@ class TestUnifiedApplyMLP(TestBase): mock_get_forward_context): mock_forward_context = MagicMock() - mock_forward_context.fused_moe_state = FusedMoEState.MC2 + mock_forward_context.moe_comm_type = MoECommType.MC2 mock_get_forward_context.return_value = mock_forward_context mock_is_310p.return_value = False @@ -582,8 +566,6 @@ class TestUnifiedApplyMLP(TestBase): with_quant=True) mock_get_forward_context.assert_called() - self.assertEqual(mock_forward_context.fused_moe_state, - FusedMoEState.MC2) mock_npu_dynamic_quant.assert_called() @@ -593,7 +575,7 @@ class TestUnifiedApplyMLP(TestBase): self.assertEqual(result.dtype, torch.bfloat16) - @patch('vllm_ascend.ops.layers.moe_mlp.is_310p') + @patch('vllm_ascend.ops.moe.moe_mlp.is_310p') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_dynamic_quant') @@ -635,7 +617,7 @@ class TestUnifiedApplyMLP(TestBase): self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.dtype, torch.float16) - @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context') + @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_dynamic_quant') @@ -695,7 +677,7 @@ class TestUnifiedApplyMLP(TestBase): self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.dtype, torch.bfloat16) - @patch('vllm_ascend.ops.layers.moe_mlp.is_310p') + @patch('vllm_ascend.ops.moe.moe_mlp.is_310p') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_dynamic_quant') @@ -739,3 +721,68 @@ class TestUnifiedApplyMLP(TestBase): self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.dtype, torch.float16) + + @patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context") + @patch("torch_npu.npu_grouped_matmul") + @patch("torch_npu.npu_swiglu") + 
@patch("torch_npu.npu_grouped_matmul_swiglu_quant") + @patch("torch_npu.npu_dynamic_quant") + def test_unified_apply_mlp_with_quantization_and_fusion_mlp( + self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant, + mock_npu_swiglu, mock_npu_grouped_matmul, + mock_get_forward_context): + + mock_forward_context = MagicMock() + mock_forward_context.with_quant = True + mock_forward_context.fused_moe_state = "NOT_MC2" + mock_get_forward_context.return_value = mock_forward_context + + mock_npu_grouped_matmul_swiglu_quant.return_value = (torch.randint( + -128, 127, (10, 40), + dtype=torch.int8), torch.rand( + 10, 1, + dtype=torch.float32), torch.rand(10, 1, dtype=torch.float32)) + mock_npu_grouped_matmul.side_effect = [[ + torch.randn(10, 20, dtype=torch.bfloat16) + ]] + mock_npu_swiglu.return_value = torch.randn(10, + 40, + dtype=torch.bfloat16) + mock_npu_dynamic_quant.return_value = (torch.randint(-128, + 127, (10, 40), + dtype=torch.int8), + torch.rand(10, + 1, + dtype=torch.float32)) + + hidden_states = torch.randn(10, 20, dtype=torch.bfloat16) + w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16) + w1_scale = torch.randn(5, 40, dtype=torch.bfloat16) + w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16) + w2_scale = torch.randn(5, 20, dtype=torch.bfloat16) + w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16) + w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16) + group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) + provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32) + + result = unified_apply_mlp(hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + dynamic_scale=provided_dynamic_scale, + group_list_type=1, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + topk_scales=None, + with_quant=True, + fusion=True) + + mock_get_forward_context.assert_called() + mock_npu_grouped_matmul.assert_called_once() + 
mock_npu_grouped_matmul_swiglu_quant.assert_called_once() + + self.assertTrue(mock_forward_context.with_quant) + self.assertEqual(result.shape, hidden_states.shape) + self.assertEqual(result.dtype, torch.bfloat16) diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index c7bc657..b0c05a2 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -1,13 +1,18 @@ -from unittest.mock import patch +import unittest import pytest import torch +from pytest_mock import MockerFixture from vllm.model_executor.layers.layernorm import RMSNorm +from tests.ut.base import PytestBase +from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod -@pytest.fixture -def dummy_tensor(): - return torch.randn(4, 8, dtype=torch.float16) + +def mock_maybe_chunk_residual(x, residual): + if x.size(0) != residual.size(0): + return residual[:4] + return residual def mock_rms_norm(x, weight, eps): @@ -18,36 +23,139 @@ def mock_add_rms_norm(x, residual, weight, eps): return 2 * x, None, 2 * residual -@pytest.mark.parametrize("is_310p_return", [True, False]) -@pytest.mark.parametrize("residual", - [None, torch.randn(4, 8, dtype=torch.float32)]) -@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) -@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) -def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return, - residual, dummy_tensor): +def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset, + epsilon): + x_out = 2 * x + residual_out = 2 * residual + x_out_quant = x_out.to(torch.int8) + residual_out_quant = residual_out.to(torch.int8) + return x_out_quant, None, residual_out_quant - with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return): - layer = RMSNorm(hidden_size=32, eps=1e-05) + +class TestAscendRMSNorm(PytestBase): + + @pytest.fixture(autouse=True) + def context(self, mocker: MockerFixture): + mocker.patch("torch.ops.vllm.maybe_chunk_residual", + 
side_effect=mock_maybe_chunk_residual) + mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) + mocker.patch("torch_npu.npu_add_rms_norm", + side_effect=mock_add_rms_norm) + mocker.patch("torch_npu.npu_add_rms_norm_quant", + side_effect=mock_add_rms_norm_quant) + mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done", + side_effect=lambda x: None) + + # Test case for the most common and basic scenario + @pytest.mark.parametrize( + "residual", [None, torch.randn(4, 8, dtype=torch.float16)]) + def test_forward_oot_basic(self, residual): + layer = RMSNorm(hidden_size=8, eps=1e-05) + x = torch.randn(4, 8, dtype=torch.float16) if residual is not None: - out_x, out_residual = layer.forward_oot(dummy_tensor, residual) + x_out, residual_out = layer.forward_oot(x, residual) - if is_310p_return: - expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype) - expected_out_x = expected_arg_x + 1 - expected_out_residual = expected_arg_x.to(residual.dtype) + x_out_expected = 2 * x + residual_out_expected = 2 * residual - mock_rmsnorm.assert_called_once() - assert torch.allclose(out_x, expected_out_x) - assert torch.allclose(out_residual, expected_out_residual) - else: - expected_out_x = 2 * dummy_tensor - expected_out_residual = 2 * residual - mock_add_rmsnorm.assert_called_once() - assert torch.allclose(out_x, expected_out_x) - assert torch.allclose(out_residual, expected_out_residual) + assert torch.allclose(x_out, x_out_expected) + assert torch.allclose(residual_out, residual_out_expected) else: - out_x = layer.forward(dummy_tensor, residual) - expected_out_x = dummy_tensor + 1 + x_out = layer.forward(x, residual) + x_out_expected = x + 1 - mock_rmsnorm.assert_called_once() - assert torch.allclose(out_x, expected_out_x) + assert torch.allclose(x_out, x_out_expected) + + # Test case for flashcomm_v1 scenario + def test_forward_oot_with_flashcomm_v1(self): + layer = RMSNorm(hidden_size=512, eps=1e-05) + x = torch.randn(4, 512, dtype=torch.bfloat16) + residual = 
torch.randn(16, 512, dtype=torch.bfloat16) + + x_out, residual_out = layer.forward_oot(x, residual) + + x_out_expected = 2 * x + residual_out_expected = 2 * residual[:4] + + assert residual_out.size(0) == 4 + assert torch.allclose(x_out, x_out_expected) + assert torch.allclose(residual_out, residual_out_expected) + + # Test case for addrmsnorm + w8a8 quant fusion + def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture): + mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p") + mock_is_310p.return_value = False + mock_get_forward_context = mocker.patch( + "vllm_ascend.ops.layernorm.get_forward_context") + + # Simulating a scenario with quant_fusion enabled + mock_forward_context = mocker.MagicMock() + + mock_model_instance = mocker.MagicMock() + mock_forward_context.model_instance = mock_model_instance + mock_model_instance.model.layers = [ + mocker.MagicMock() for _ in range(2) + ] + + mock_layer_0 = mock_model_instance.model.layers[0] + mock_layer_0.self_attn.qkv_proj = mocker.MagicMock() + mock_layer_0.mlp.gate_up_proj = mocker.MagicMock() + + mock_layer_1 = mock_model_instance.model.layers[1] + mock_layer_1.self_attn.qkv_proj = mocker.MagicMock() + mock_layer_1.mlp.gate_up_proj = mocker.MagicMock() + + mock_quant_method_0_qkv = mocker.MagicMock() + mock_quant_method_0_qkv.quant_method = AscendW8A8LinearMethod() + mock_quant_method_0_gate_up = mocker.MagicMock() + mock_quant_method_0_gate_up.quant_method = AscendW8A8LinearMethod() + mock_layer_0.self_attn.qkv_proj.quant_method = mock_quant_method_0_qkv + mock_layer_0.mlp.gate_up_proj.quant_method = mock_quant_method_0_gate_up + + mock_quant_method_1_qkv = mocker.MagicMock() + mock_quant_method_1_qkv.quant_method = AscendW8A8LinearMethod() + mock_quant_method_1_gate_up = mocker.MagicMock() + mock_quant_method_1_gate_up.quant_method = AscendW8A8LinearMethod() + mock_layer_1.self_attn.qkv_proj.quant_method = mock_quant_method_1_qkv + mock_layer_1.mlp.gate_up_proj.quant_method = 
mock_quant_method_1_gate_up + + mock_get_forward_context.return_value = mock_forward_context + + mock_forward_context.addrmsnorm_quant_fusion_enabled = True + mock_forward_context.prefetch_mlp_enabled = False + mock_forward_context.layer_idx = 0 + mock_forward_context.num_hidden_layers = 2 + mock_forward_context.fusion_linear = "gate_up_dense" + + # Ensure fusion and layer_idx increment are handled correctly + x = torch.randn(4, 8, dtype=torch.float16) + residual = torch.randn(4, 8, dtype=torch.float16) + layer = RMSNorm(hidden_size=8, eps=1e-05) + + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 1 + assert mock_forward_context.fusion_linear == "qkv_dense" + assert mock_forward_context.layer_idx == 1 + + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 2 + assert mock_forward_context.fusion_linear == "gate_up_dense" + assert mock_forward_context.layer_idx == 1 + + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 3 + assert mock_forward_context.fusion_linear == "qkv_dense" + assert mock_forward_context.layer_idx == 2 + + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 4 + assert mock_forward_context.fusion_linear == "qkv_dense" + assert mock_forward_context.layer_idx == 2 + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py index 28b26b7..e22d7ca 100644 --- a/tests/ut/ops/test_linear.py +++ b/tests/ut/ops/test_linear.py @@ -1,363 +1,96 @@ import os import unittest from unittest import mock +from unittest.mock import MagicMock, patch import torch -from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear, - AscendMlpMergedColumnParallelLinear, - AscendMlpRowParallelLinear, LinearBase, - QuantizationConfig) +from vllm_ascend import ascend_config +from vllm_ascend.distributed 
import parallel_state +from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear, + AscendRowParallelLinear) -class TestAscendMlpRowParallelLinear(unittest.TestCase): +class BaseLinearTest(unittest.TestCase): def setUp(self): - os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1" - self.tensor_parallel_world_size = 2 - self.tensor_parallel_rank = 0 - self.mlp_tensor_parallel_world_size = 2 - self.mlp_tensor_parallel_rank = 1 + self.mock_group = mock.MagicMock() + self.mock_group.world_size = 2 + self.mock_group.rank_in_group = 0 - self.get_tensor_model_parallel_world_size_patch = mock.patch( - 'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size', - return_value=self.tensor_parallel_world_size) - self.get_tensor_model_parallel_rank_patch = mock.patch( - 'vllm_ascend.ops.linear.get_tensor_model_parallel_rank', - return_value=self.tensor_parallel_rank) - self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch( - 'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size', - return_value=self.mlp_tensor_parallel_world_size) - self.get_mlp_tensor_model_parallel_rank_patch = mock.patch( - 'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank', - return_value=self.mlp_tensor_parallel_rank) + parallel_state._MLP_TP = self.mock_group + parallel_state._OTP = self.mock_group - self.get_tensor_model_parallel_world_size_mock = \ - self.get_tensor_model_parallel_world_size_patch.start() - self.get_tensor_model_parallel_rank_mock = \ - self.get_tensor_model_parallel_rank_patch.start() - self.get_mlp_tensor_model_parallel_world_size_mock = \ - self.get_mlp_tensor_model_parallel_world_size_patch.start() - self.get_mlp_tensor_model_parallel_rank_mock = \ - self.get_mlp_tensor_model_parallel_rank_patch.start() + self.mock_ascend_config = MagicMock() + self.mock_ascend_config.oproj_tensor_parallel_size = 2 - self.split_tensor_along_last_dim_patch = mock.patch( - 'vllm_ascend.ops.linear.split_tensor_along_last_dim', - return_value=(torch.randn(10, 
8), torch.randn(10, 8))) - self.tensor_model_parallel_all_reduce_patch = mock.patch( - 'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce', - return_value=torch.randn(10, 8)) - self.tensor_model_parallel_all_reduce_mock = \ - self.tensor_model_parallel_all_reduce_patch.start() - self.split_tensor_along_last_dim_mock = \ - self.split_tensor_along_last_dim_patch.start() - self.get_mlp_tp_group_patch = \ - mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group') - self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start() - self.get_mlp_tp_group_mock.return_value = mock.MagicMock() - self.get_mlp_tp_group_mock.return_value.reduce_scatter = \ - mock.MagicMock() + self.patches = [ + patch("vllm_ascend.ascend_config.get_ascend_config", + return_value=self.mock_ascend_config), + patch("vllm_ascend.distributed.parallel_state.get_otp_group", + return_value=self.mock_group), + patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group", + return_value=self.mock_group), + patch("vllm_ascend.ops.linear_op.get_tp_group", + return_value=self.mock_group), + patch( + "vllm.distributed.parallel_state.get_tp_group", + return_value=self.mock_group, + ), + patch("vllm_ascend.utils.mlp_tp_enable", return_value=True), + patch("vllm_ascend.utils.oproj_tp_enable", return_value=True) + ] + + for p in self.patches: + p.start() def tearDown(self): - self.get_tensor_model_parallel_world_size_patch.stop() - self.get_tensor_model_parallel_rank_patch.stop() - self.get_mlp_tensor_model_parallel_world_size_patch.stop() - self.get_mlp_tensor_model_parallel_rank_patch.stop() - self.split_tensor_along_last_dim_patch.stop() - self.tensor_model_parallel_all_reduce_patch.stop() - self.get_mlp_tp_group_patch.stop() + for p in self.patches: + p.stop() - def test_init_with_down_proj_prefix(self): - layer = AscendMlpRowParallelLinear(input_size=16, - output_size=8, - prefix="down_proj") - self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size) - self.assertEqual(layer.tp_rank, 
self.mlp_tensor_parallel_rank) - self.assertTrue(layer.enable_mlp_optimze) - def test_forward_with_mlp_optimize(self): - layer = AscendMlpRowParallelLinear( +class TestAscendRowParallelLinear(BaseLinearTest): + + def test_mlp_optimize(self): + os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1" + + linear = AscendRowParallelLinear( input_size=16, output_size=8, prefix="down_proj", - input_is_parallel=False, ) - input_tensor = torch.randn(16, 8) # (batch_size, input_size) - layer(input_tensor) + self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP) - self.split_tensor_along_last_dim_mock.assert_called_once_with( - input_tensor, num_partitions=layer.tp_size) + input_tensor = torch.randn(16, 8) + linear(input_tensor) - def test_forward_without_mlp_optimize(self): - layer = AscendMlpRowParallelLinear( + def test_oproj_tp(self): + ascend_config._ASCEND_CONFIG = MagicMock() + ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2 + + linear = AscendRowParallelLinear( input_size=16, output_size=8, - prefix="other", - input_is_parallel=False, + prefix="o_proj", ) + self.assertEqual(linear.custom_op.comm_group, parallel_state._OTP) + input_tensor = torch.randn(16, 8) - layer(input_tensor) + linear(input_tensor) - self.split_tensor_along_last_dim_mock.assert_called_once_with( - input_tensor, num_partitions=layer.tp_size) - self.tensor_model_parallel_all_reduce_mock.assert_called_once() - def test_skip_bias_add(self): - layer = AscendMlpRowParallelLinear( +class TestAscendMergedColumnParallelLinear(BaseLinearTest): + + def test_merged_mlp_tp_init(self): + os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1" + + linear = AscendMergedColumnParallelLinear( input_size=16, - output_size=8, - skip_bias_add=True, + output_sizes=[8, 8], + prefix="gate_up_proj", ) - input_tensor = torch.randn(16, 8) - output, bias = layer(input_tensor) - - self.assertIsNotNone(bias) - - def test_no_reduce_results(self): - layer = AscendMlpRowParallelLinear(input_size=16, - 
output_size=8, - reduce_results=False, - bias=False) - input_tensor = torch.randn(16, 8) - layer(input_tensor) - - self.tensor_model_parallel_all_reduce_mock.assert_not_called() - - def test_input_not_parallel(self): - layer = AscendMlpRowParallelLinear(input_size=16, - output_size=8, - input_is_parallel=False) - input_tensor = torch.randn(16, 8) - layer(input_tensor) - - self.split_tensor_along_last_dim_mock.assert_called_once() - - def test_exception_when_reduce_false_and_bias(self): - with self.assertRaises(ValueError): - AscendMlpRowParallelLinear(input_size=16, - output_size=8, - reduce_results=False, - bias=True, - skip_bias_add=False) + self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP) -class TestAscendMlpColumnParallelLinear(unittest.TestCase): - - def setUp(self): - os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1" - # Mock distributed functions - self.mlp_tp_size_patch = \ - mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size') - self.mlp_tp_size_mock = self.mlp_tp_size_patch.start() - self.mlp_tp_size_mock.return_value = 2 # Simulate 2 GPUs in MLP TP group - - self.mlp_tp_rank_patch = \ - mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank') - self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start() - self.mlp_tp_rank_mock.return_value = 0 # Current GPU rank - - self.tp_size_patch = \ - mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size') - self.tp_size_mock = self.tp_size_patch.start() - self.tp_size_mock.return_value = 4 # Simulate 4 GPUs in regular TP group - - self.tp_rank_patch = \ - mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank') - self.tp_rank_mock = self.tp_rank_patch.start() - self.tp_rank_mock.return_value = 1 # Current GPU rank - - # Mock divide function (assumed to be in your module) - self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide') - self.divide_mock = self.divide_patch.start() - self.divide_mock.side_effect = lambda x, y: x 
// y # Simulate division - - # Mock QuantizationConfig and QuantMethod - self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig) - - # Mock LinearBase initialization - self.linear_base_init_patch = mock.patch.object( - LinearBase, "__init__", side_effect=self.mock_linear_base_init) - self.linear_base_init_patch.start() - - self.quant_method_mock = mock.MagicMock() - - def mock_linear_base_init(self, instance, *args, **kwargs): - instance.quant_method = self.quant_method_mock - instance.params_dtype = mock.MagicMock() - - instance.input_size = 16 - instance.output_size = 8 - instance.output_size_per_partition = 4 - instance.params_dtype = torch.float32 - - def tearDown(self): - self.mlp_tp_size_patch.stop() - self.mlp_tp_rank_patch.stop() - self.tp_size_patch.stop() - self.tp_rank_patch.stop() - self.divide_patch.stop() - self.linear_base_init_patch.stop() - - def test_mlp_optimize_initialization(self): - # Test when prefix contains "gate_up_proj" - with mock.patch.object(torch.nn.Module, 'register_parameter'): - layer = AscendMlpColumnParallelLinear( - input_size=16, - output_size=8, - prefix="model.layers.0.gate_up_proj", - bias=False, - ) - - # Verify MLP optimization flags - self.assertTrue(layer.enable_mlp_optimze) - self.assertEqual(layer.tp_size, 2) - self.assertEqual(layer.tp_rank, 0) - self.assertEqual(layer.input_size_per_partition, 16) - self.assertEqual(layer.output_size_per_partition, 4) - - # Check quant_method.create_weights was called - self.quant_method_mock.create_weights.assert_called_once() - - def test_regular_parallel_initialization(self): - # Test when prefix does NOT contain "gate_up_proj" - with mock.patch.object(torch.nn.Module, 'register_parameter'): - layer = AscendMlpColumnParallelLinear( - input_size=16, - output_size=8, - prefix="model.layers.0.q_proj", - quant_config=self.quant_config_mock, - bias=False, - ) - - # Verify regular TP flags - self.assertFalse(layer.enable_mlp_optimze) - self.assertEqual(layer.tp_size, 4) - 
self.assertEqual(layer.tp_rank, 1) - self.assertEqual(layer.input_size_per_partition, 16) - self.assertEqual(layer.output_size_per_partition, 4) - # Check quant_method.create_weights was called - self.quant_method_mock.create_weights.assert_called_once() - - def test_output_sizes_handling(self): - # Test when output_sizes is provided - with mock.patch.object(torch.nn.Module, 'register_parameter'): - layer = AscendMlpColumnParallelLinear( - input_size=16, - output_size=8, - output_sizes=[4, 4], - prefix="model.layers.0.qkv_proj", - quant_config=self.quant_config_mock, - bias=False, - ) - - # Verify output_partition_sizes - self.assertEqual(layer.output_partition_sizes, [2]) - - -class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase): - - def setUp(self): - os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1" - # Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size - self.mlp_world_size_patch = \ - mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2) - self.tensor_world_size_patch = \ - mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2) - self.mlp_world_size_patch.start() - self.tensor_world_size_patch.start() - - # Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank - self.mlp_rank_patch = \ - mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0) - self.tensor_rank_patch = \ - mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank", return_value=0) - self.mlp_rank_patch.start() - self.tensor_rank_patch.start() - - # Mock all_gather methods - self.get_mlp_tp_group_patch = \ - mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group') - self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start() - self.get_mlp_tp_group_mock.return_value = mock.MagicMock() - self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock() - self.tensor_model_parallel_all_gather_patch = 
mock.patch( - 'vllm_ascend.ops.linear.tensor_model_parallel_all_gather', - return_value=torch.randn(10, 8)) - self.tensor_model_parallel_all_gather_mock = \ - self.tensor_model_parallel_all_gather_patch.start() - - # Mock AscendMlpColumnParallelLinear's __init__ - self.linear_init_patch = mock.patch.object( - AscendMlpColumnParallelLinear, - "__init__", - side_effect=self.mock_linear_init) - self.linear_init_patch.start() - - # Create mock objects - self.quant_method_mock = mock.MagicMock() - self.apply_output = torch.randn(2, 8) - - self.quant_method_mock.apply.return_value = self.apply_output - - def mock_linear_init(self, instance, *args, **kwargs): - torch.nn.Module.__init__(instance) - # Set quant_method and other attributes - instance.quant_method = self.quant_method_mock - instance.bias = torch.nn.Parameter(torch.randn(8)) # Example bias - instance.input_size = 16 - instance.output_size = 8 - instance.gather_output = False - instance.skip_bias_add = False - instance.return_bias = True - - def test_forward_with_enable_mlp_optimze(self): - # Setup input - input_tensor = torch.randn(1, 16) - - # Create instance with prefix "gate_up_proj" to trigger enable_mlp_optimze = True - layer = AscendMlpMergedColumnParallelLinear(input_size=16, - output_sizes=[8], - bias=True, - gather_output=False, - skip_bias_add=False, - params_dtype=torch.float32, - quant_config=None, - prefix="other_proj") - - # Call forward - output, bias = layer(input_tensor) - - # Validate calls - self.assertEqual(output.shape, self.apply_output.shape) - - def test_forward_without_enable_mlp_optimze(self): - # Setup input - input_tensor = torch.randn(1, 16) - - # Create instance with prefix not containing "gate_up_proj" - layer = AscendMlpMergedColumnParallelLinear(input_size=16, - output_sizes=[8], - bias=True, - gather_output=False, - skip_bias_add=False, - params_dtype=torch.float32, - quant_config=None, - prefix="other_proj") - - # Call forward - output, bias = layer(input_tensor) - - # 
Validate calls - self.quant_method_mock.apply.assert_called_once_with( - layer, input_tensor, layer.bias) - self.tensor_model_parallel_all_gather_mock.assert_not_called() - self.assertEqual(output.shape, self.apply_output.shape) - - def tearDown(self): - self.linear_init_patch.stop() - self.mlp_world_size_patch.stop() - self.tensor_world_size_patch.stop() - self.mlp_rank_patch.stop() - self.tensor_rank_patch.stop() - self.get_mlp_tp_group_mock.stop() - self.tensor_model_parallel_all_gather_mock.stop() +if __name__ == '__main__': + unittest.main() diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py new file mode 100644 index 0000000..97aea93 --- /dev/null +++ b/tests/ut/ops/test_moe_comm_method.py @@ -0,0 +1,232 @@ +from unittest.mock import MagicMock, patch + +import torch +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + +from tests.ut.base import TestBase +from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl, + AlltoAllCommImpl, MC2CommImpl) + + +class TestMoECommMethod(TestBase): + + def setUp(self): + # Mock FusedMoEConfig + self.moe_config = MagicMock(spec=FusedMoEConfig) + self.moe_config.num_experts = 8 + self.moe_config.num_local_experts = 2 + self.moe_config.experts_per_token = 2 + self.moe_config.tp_group = MagicMock() + self.moe_config.tp_group.device_group = MagicMock() + self.moe_config.dp_size = 1 + self.moe_config.tp_size = 1 + self.moe_config.ep_size = 1 + self.moe_config.dp_group = MagicMock() + self.moe_config.num_global_redundant_experts = 0 + + @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") + @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch( + "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather" + ) + @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather") + def test_all_gather_comm_impl(self, mock_token_dispatcher, + mock_prepare_finalize, + mock_get_forward_context, + 
mock_get_current_vllm_config): + # Mock vLLM config + mock_get_current_vllm_config.return_value = MagicMock() + + # Mock forward context + mock_context = MagicMock() + mock_context.moe_comm_method = "all_gather" + mock_get_forward_context.return_value = mock_context + + # Mock prepare finalize + mock_pf_instance = MagicMock() + mock_pf_instance.prepare.return_value = (torch.randn(4, 8), + torch.randn(4, 2), None) + mock_pf_instance.finalize.return_value = torch.randn(4, 8) + mock_prepare_finalize.return_value = mock_pf_instance + + # Mock token dispatcher + mock_td_instance = MagicMock() + mock_token_dispatcher.return_value = mock_td_instance + + # Create instance + comm_impl = AllGatherCommImpl(self.moe_config) + + # Test prepare method + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + h_out, r_out = comm_impl.prepare(hidden_states, router_logits) + + # Verify prepare was called with correct arguments + mock_pf_instance.prepare.assert_called_once_with( + hidden_states, router_logits, False, False, False, None) + + # Test finalize method + comm_impl.finalize(h_out, reduce_results=True) + mock_pf_instance.finalize.assert_called_once_with(h_out, True) + + @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") + @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch( + "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2" + ) + @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2") + def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize, + mock_get_forward_context, + mock_get_current_vllm_config): + # Mock vLLM config + mock_get_current_vllm_config.return_value = MagicMock() + + # Mock forward context + mock_context = MagicMock() + mock_context.moe_comm_method = "mc2" + mock_get_forward_context.return_value = mock_context + + # Mock prepare finalize + mock_pf_instance = MagicMock() + mock_pf_instance.prepare.return_value = (torch.randn(4, 8), + torch.randn(4, 
2), + torch.tensor([1, 0, 1, 0])) + mock_pf_instance.finalize.return_value = torch.randn(4, 8) + mock_prepare_finalize.return_value = mock_pf_instance + + # Mock token dispatcher + mock_td_instance = MagicMock() + mock_token_dispatcher.return_value = mock_td_instance + + # Create instance + comm_impl = MC2CommImpl(self.moe_config) + + # Test prepare method + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + h_out, r_out = comm_impl.prepare(hidden_states, router_logits) + + # Verify prepare was called with correct arguments + mock_pf_instance.prepare.assert_called_once_with( + hidden_states, router_logits, False, False, False, None) + + # Test finalize method + comm_impl.finalize(h_out, reduce_results=True) + mock_pf_instance.finalize.assert_called_once_with(h_out, True) + + @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") + @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch( + "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All" + ) + @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV") + def test_alltoall_comm_impl(self, mock_token_dispatcher, + mock_prepare_finalize, + mock_get_forward_context, + mock_get_current_vllm_config): + # Mock vLLM config + mock_get_current_vllm_config.return_value = MagicMock() + + # Mock forward context + mock_context = MagicMock() + mock_context.moe_comm_method = "alltoall" + mock_get_forward_context.return_value = mock_context + + # Mock prepare finalize + mock_pf_instance = MagicMock() + mock_pf_instance.prepare.return_value = (torch.randn(4, 8), + torch.randn(4, 2), None) + mock_pf_instance.finalize.return_value = torch.randn(4, 8) + mock_prepare_finalize.return_value = mock_pf_instance + + # Mock token dispatcher + mock_td_instance = MagicMock() + mock_token_dispatcher.return_value = mock_td_instance + + # Create instance + comm_impl = AlltoAllCommImpl(self.moe_config) + + # Test prepare method + hidden_states = 
torch.randn(3, 8) + router_logits = torch.randn(3, 2) + h_out, r_out = comm_impl.prepare(hidden_states, router_logits) + + # Verify prepare was called with correct arguments + mock_pf_instance.prepare.assert_called_once_with( + hidden_states, router_logits, False, False, False, None) + + @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") + @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch( + "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather" + ) + @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather") + @patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp") + def test_fused_experts_method(self, mock_unified_apply_mlp, + mock_token_dispatcher, mock_prepare_finalize, + mock_get_forward_context, + mock_get_current_vllm_config): + # Mock vLLM config + mock_get_current_vllm_config.return_value = MagicMock() + + # Mock forward context + mock_context = MagicMock() + mock_context.moe_comm_method = "all_gather" + mock_get_forward_context.return_value = mock_context + + # Mock prepare finalize + mock_pf_instance = MagicMock() + mock_pf_instance.prepare.return_value = (torch.randn(4, 8), + torch.randn(4, 2), None) + mock_pf_instance.finalize.return_value = torch.randn(4, 8) + mock_prepare_finalize.return_value = mock_pf_instance + + # Mock token dispatcher + mock_td_instance = MagicMock() + mock_td_instance.token_dispatch.return_value = { + "hidden_states": torch.randn(6, 8), + "group_list": torch.tensor([2, 2, 2]), + "group_list_type": 1 + } + mock_td_instance.token_combine.return_value = torch.randn(4, 8) + mock_token_dispatcher.return_value = mock_td_instance + + # Mock unified_apply_mlp + mock_unified_apply_mlp.return_value = torch.randn(6, 8) + + # Create instance + comm_impl = AllGatherCommImpl(self.moe_config) + + # Test fused_experts method + hidden_states = torch.randn(4, 8).contiguous() + w1 = torch.randn(16, 8).contiguous() + w2 = torch.randn(16, 8).contiguous() + 
topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2], + [0.6, 0.4]]) + topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]]) + row_idx = torch.arange(4) + + # Make sure tensors are contiguous and have correct strides + hidden_states = hidden_states.contiguous() + w1 = w1.contiguous() + w2 = w2.contiguous() + + result = comm_impl.fused_experts(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + activation="silu") + + # Verify result shape + self.assertEqual(result.shape, (4, 8)) + + # Verify token_dispatch was called + mock_td_instance.token_dispatch.assert_called_once() + + # Verify unified_apply_mlp was called + mock_unified_apply_mlp.assert_called_once() + + # Verify token_combine was called + mock_td_instance.token_combine.assert_called_once() diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py index eb48c81..21d95bb 100644 --- a/tests/ut/ops/test_rotary_embedding.py +++ b/tests/ut/ops/test_rotary_embedding.py @@ -3,12 +3,18 @@ import unittest from unittest.mock import MagicMock, PropertyMock, patch import torch +from transformers.configuration_utils import PretrainedConfig +from vllm.config import ModelConfig, VllmConfig from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) from tests.ut.base import TestBase +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled +MODEL = "Qwen3-0.6B" +MAX_NUM_BATCHED_TOKEND = 10000 + class TestCustomRotaryEmbeddingEnabled(unittest.TestCase): @@ -88,11 +94,15 @@ class TestAscendRotaryEmbedding(unittest.TestCase): self.mock_self.cos_sin_cache = self.cos_sin_cache self.mock_self.is_neox_style = self.is_neox_style - @patch('torch.ops._C') + @patch('torch.ops._C_ascend') @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False) 
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', return_value=True) @patch('torch.ops._npu_rotary_embedding') + @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) + @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) + @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) + @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding, mock_custom_enabled, mock_is_310p, mock__c): @@ -102,9 +112,15 @@ class TestAscendRotaryEmbedding(unittest.TestCase): # Setup mock for custom kernel path mock__c.rotary_embedding.return_value = self.query, self.key - - result_q, result_k = self.layer.forward(self.positions, self.query, - self.key) + vllm_config = VllmConfig() + model_config = ModelConfig(MODEL, + tokenizer=MODEL, + max_model_len=MAX_NUM_BATCHED_TOKEND) + model_config.hf_config = PretrainedConfig() + vllm_config.model_config = model_config + with set_ascend_forward_context(None, vllm_config): + result_q, result_k = self.layer.forward(self.positions, self.query, + self.key) mock__c.rotary_embedding.assert_called_once() self.assertEqual(result_q.shape, self.query.shape) @@ -113,6 +129,10 @@ class TestAscendRotaryEmbedding(unittest.TestCase): @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', return_value=False) @patch('torch_npu._npu_rotary_embedding') + @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) + @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) + @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) + @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_contiguous(self, mock_npu_rotary, mock_custom_enabled): mock_config = MagicMock() @@ -121,15 +141,25 @@ class TestAscendRotaryEmbedding(unittest.TestCase): # Test contiguous path when custom is disabled non_contig_query = self.query.transpose(0, 1) non_contig_key = 
self.key.transpose(0, 1) - - result_q, result_k = self.layer.forward(self.positions, - non_contig_query, - non_contig_key) + vllm_config = VllmConfig() + model_config = ModelConfig(MODEL, + tokenizer=MODEL, + max_model_len=MAX_NUM_BATCHED_TOKEND) + model_config.hf_config = PretrainedConfig() + vllm_config.model_config = model_config + with set_ascend_forward_context(None, vllm_config): + result_q, result_k = self.layer.forward(self.positions, + non_contig_query, + non_contig_key) mock_npu_rotary.assert_called_once() self.assertEqual(result_q.shape, non_contig_query.shape) self.assertEqual(result_k.shape, non_contig_key.shape) + @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) + @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) + @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) + @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_with_offsets(self): mock_config = MagicMock() mock_config.torchair_graph_config.enabled = False @@ -137,26 +167,78 @@ class TestAscendRotaryEmbedding(unittest.TestCase): # Test that NotImplementedError is raised when offsets is provided offsets = torch.tensor([1, 2, 3]) with self.assertRaises(NotImplementedError): - self.layer.forward(self.positions, self.query, self.key, offsets) + vllm_config = VllmConfig() + model_config = ModelConfig(MODEL, + tokenizer=MODEL, + max_model_len=MAX_NUM_BATCHED_TOKEND) + model_config.hf_config = PretrainedConfig() + vllm_config.model_config = model_config + with set_ascend_forward_context(None, vllm_config): + self.layer.forward(self.positions, self.query, self.key, + offsets) @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', return_value=False) @patch('torch_npu._npu_rotary_embedding') + @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) + @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) + @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) + 
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary, mock_custom_enabled): mock_config = MagicMock() mock_config.torchair_graph_config.enabled = False # Test neox_style override - result_q, result_k = self.layer.forward(self.positions, - self.query, - self.key, - is_neox_style_override=False) - + vllm_config = VllmConfig() + model_config = ModelConfig(MODEL, + tokenizer=MODEL, + max_model_len=MAX_NUM_BATCHED_TOKEND) + model_config.hf_config = PretrainedConfig() + vllm_config.model_config = model_config + with set_ascend_forward_context(None, vllm_config): + result_q, result_k = self.layer.forward( + self.positions, + self.query, + self.key, + is_neox_style_override=False) # Check that neox_style=False was passed to the NPU function args, kwargs = mock_npu_rotary.call_args self.assertFalse(args[-1]) + @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', + return_value=False) + @patch('torch_npu._npu_rotary_embedding') + @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) + @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) + @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) + @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) + def test_rope_forward_oot_rotary_dim_less_than_head_size( + self, mock_npu_rotary, mock_custom_enabled): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + + # test case when rotary_dim < head_size + org_rotary_dim = self.layer.rotary_dim + self.layer.rotary_dim = self.layer.head_size // 2 + + vllm_config = VllmConfig() + model_config = ModelConfig(MODEL, + tokenizer=MODEL, + max_model_len=MAX_NUM_BATCHED_TOKEND) + model_config.hf_config = PretrainedConfig() + vllm_config.model_config = model_config + with set_ascend_forward_context(None, vllm_config): + result_q, result_k = self.layer.forward(self.positions, self.query, + self.key) + + 
mock_npu_rotary.assert_called_once() + self.assertEqual(result_q.shape, self.query.shape) + self.assertEqual(result_k.shape, self.key.shape) + + # restore rotary_dim + self.layer.rotary_dim = org_rotary_dim + class MockRopeModule: @@ -207,28 +289,6 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase): assert q_pe.shape == self.query.shape assert k_pe.shape == self.key.shape - @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot') - @patch("vllm.platforms.current_platform.device_type", - new=torch.device("cpu")) - @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", - new_callable=PropertyMock) - def test_native_rope_deepseek_forward_cache_handling( - self, mock_npuplatform, mock_rope_forward_oot): - mock_npuplatform.device_type = torch.device("cpu") - self.layer = self._create_layer() - self.layer.max_seq_len = 1024 - # Test cache situation is true - with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache: - mock_rope_forward_oot.return_value = (self.query, self.key) - - q_pe, k_pe = self.layer.forward(self.positions, - self.query, - self.key, - max_seq_len=2048) - mock_set_cache.assert_called_once() - assert q_pe.shape == self.query.shape - assert k_pe.shape == self.key.shape - @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot') @patch("vllm.platforms.current_platform.device_type", new=torch.device("cpu")) diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 9de8a13..cc2d307 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -20,10 +20,10 @@ from unittest.mock import MagicMock, PropertyMock, patch import torch from tests.ut.base import TestBase -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + +from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip AscendSocVersion, TokenDispatcherWithAll2AllV, - TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers, - _register_token_dispatcher, 
get_token_dispatcher, setup_token_dispatchers) + TokenDispatcherWithAllGather, TokenDispatcherWithMC2) class TestTokenDispatcherWithMC2(TestBase): @@ -34,7 +34,7 @@ class TestTokenDispatcherWithMC2(TestBase): self.mc2_group.rank_in_group = 0 self.mc2_group.world_size = 8 self.mc2_group_patch = patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group", + "vllm_ascend.ops.moe.token_dispatcher.get_mc2_group", return_value=self.mc2_group) self.mc2_group_patch.start() @@ -52,7 +52,7 @@ class TestTokenDispatcherWithMC2(TestBase): # Mock get_ascend_soc_version() self.ascend_soc_version_patch = patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version", + "vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=AscendSocVersion.A3) self.ascend_soc_version_patch.start() @@ -98,7 +98,7 @@ class TestTokenDispatcherWithMC2(TestBase): self.row_idx, expert_map) mock_dispatch.assert_called_once() self.assertEqual(output["group_list_type"], - 1) # group_list_type == 1 + 0) # group_list_type == 0 def test_token_dispatch_with_shared_experts_and_quant(self): self.shared_experts = MagicMock() @@ -171,32 +171,25 @@ class TestTokenDispatcherWithAllGather(TestBase): self.dispatcher = TokenDispatcherWithAllGather(**kwargs) # Mock NPU functions - self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing') - self.mock_moe_init_routing = self.patcher_moe_init_routing.start() - self.mock_moe_init_routing.return_value = ( + self.patcher_npu_moe_init_routing_v2 = patch( + 'torch_npu.npu_moe_init_routing_v2') + self.mock_npu_moe_init_routing_v2 = self.patcher_npu_moe_init_routing_v2.start( + ) + self.mock_npu_moe_init_routing_v2.return_value = ( torch.randn(6, 128), # sorted_hidden_states torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx - torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx - ) - - self.patcher_moe_compute_expert_tokens = patch( - 'torch_npu.npu_moe_compute_expert_tokens') - 
self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start( - ) - self.mock_moe_compute_expert_tokens.return_value = torch.tensor( - [3, 3]) # expert_tokens - - self.patcher_moe_finalize_routing = patch( - 'torch_npu.npu_moe_finalize_routing') - self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start( - ) - self.mock_moe_finalize_routing.return_value = torch.randn(3, 128) + torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx + torch.tensor([0, 1, 0, 1, 0, 1])) self.row_idx = torch.arange(10, dtype=torch.int32) + self.patcher_npu_moe_token_unpermute = patch( + 'torch_npu.npu_moe_token_unpermute') + self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start( + ) + self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128) def tearDown(self): - self.patcher_moe_init_routing.stop() - self.patcher_moe_compute_expert_tokens.stop() - self.patcher_moe_finalize_routing.stop() + self.patcher_npu_moe_init_routing_v2.stop() + self.patcher_npu_moe_token_unpermute.stop() def test_token_dispatch_without_expert_map(self): hidden_states = torch.randn(3, 128) @@ -207,12 +200,27 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_ids, self.row_idx, None) # Verify npu_moe_init_routing is called - self.mock_moe_init_routing.assert_called_once() - args, kwargs = self.mock_moe_init_routing.call_args + self.mock_npu_moe_init_routing_v2.assert_called_once() + args, kwargs = self.mock_npu_moe_init_routing_v2.call_args - self.assertEqual(results["group_list_type"], 0) + self.assertEqual(results["group_list_type"], 1) - def test_token_dispatch_with_quant(self): + def test_token_dispatch_with_expert_map(self): + self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) + hidden_states = torch.randn(3, 128) + topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) + topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) + + results = self.dispatcher.token_dispatch(hidden_states, topk_weights, + topk_ids, 
self.row_idx, None) + + # Verify npu_moe_init_routing is called + self.mock_npu_moe_init_routing_v2.assert_called_once() + args, kwargs = self.mock_npu_moe_init_routing_v2.call_args + + self.assertEqual(results["group_list_type"], 1) + + def test_token_dispatch_without_quant(self): kwargs = { "apply_router_weight_on_input": False, "top_k": 2, @@ -230,7 +238,33 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_weights, topk_ids, self.row_idx, None) - self.assertEqual(results["group_list_type"], 0) + self.assertEqual(results["group_list_type"], 1) + + def test_token_dispatch_with_quant(self): + kwargs = { + "apply_router_weight_on_input": False, + "top_k": 2, + "max_num_tokens": 100, + "ep_size": 2, + "num_experts": 128, + } + self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs) + + hidden_states = torch.randn(3, 128) + topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) + topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) + + results = self.dispatcher_quant.token_dispatch(hidden_states, + topk_weights, + topk_ids, + self.row_idx, + None, + with_quant=True) + + self.assertIsNotNone(results["hidden_states"]) + self.assertIsNotNone(results["group_list"]) + self.assertIsNotNone(results["dynamic_scale"]) + self.assertEqual(results["group_list_type"], 1) def test_token_combine_with_expert_map(self): self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) @@ -242,9 +276,7 @@ class TestTokenDispatcherWithAllGather(TestBase): hidden_states = torch.randn(6, 128) final_hidden_states = self.dispatcher.token_combine(hidden_states) - - # Verify index_add_ is applied correctly - self.assertEqual(final_hidden_states.shape, (3, 128)) + self.assertEqual(final_hidden_states.shape, (6, 128)) def test_token_combine_without_expert_map(self): self.dispatcher.with_quant = False @@ -260,10 +292,10 @@ class TestTokenDispatcherWithAllGather(TestBase): final_hidden_states = self.dispatcher.token_combine(hidden_states) # Verify npu_moe_finalize_routing is called - 
self.mock_moe_finalize_routing.assert_called_once() - args, kwargs = self.mock_moe_finalize_routing.call_args + self.mock_npu_moe_token_unpermute.assert_called_once() + args, kwargs = self.mock_npu_moe_token_unpermute.call_args - self.assertEqual(final_hidden_states.shape, (3, 128)) + self.assertEqual(final_hidden_states.shape, (6, 128)) def test_token_dispatch_with_router_weight(self): self.dispatcher.apply_router_weight_on_input = True @@ -315,7 +347,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase): self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16) # Mock async_all_to_all - patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all') + patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all') self.mock_async_all_to_all = patcher6.start() self.addCleanup(patcher6.stop) self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16), @@ -323,7 +355,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase): # Mock gather_from_sequence_parallel_region patcher7 = patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region' + 'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region' ) self.mock_gather_from_sequence_parallel_region = patcher7.start() self.addCleanup(patcher7.stop) @@ -488,119 +520,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase): self.assertIsNotNone(result["hidden_states"]) self.assertIsNotNone(result["group_list"]) self.assertEqual(result["group_list_type"], 1) - - -class TestDispatcherRegistry(TestBase): - - def setUp(self): - _Dispatchers.clear() - - def tearDown(self): - _Dispatchers.clear() - - def test_register_and_get_token_dispatcher(self): - mock_dispatcher = MagicMock() - mock_dispatcher.__class__.__name__ = "MockDispatcher" - - _register_token_dispatcher(mock_dispatcher) - - self.assertIn("MockDispatcher", _Dispatchers) - self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher) - - retrieved_dispatcher = get_token_dispatcher("MockDispatcher") - 
self.assertIs(retrieved_dispatcher, mock_dispatcher) - - self.assertIsNone(get_token_dispatcher("NonExistentDispatcher")) - - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' - ) - def test_setup_token_dispatchers_ep_size_1_creates_allgather( - self, mock_register, mock_allgather_class): - kwargs = {"top_k": 2, "num_experts": 8} - mock_instance = MagicMock() - mock_allgather_class.return_value = mock_instance - - self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers) - - setup_token_dispatchers(ep_size=1, **kwargs) - - mock_allgather_class.assert_called_once_with(**kwargs) - mock_register.assert_called_once_with(mock_instance) - - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' - ) - def test_setup_token_dispatchers_ep_size_2_creates_all2allv( - self, mock_register, mock_all2allv_class): - kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2} - mock_instance = MagicMock() - mock_all2allv_class.return_value = mock_instance - - self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers) - - setup_token_dispatchers(ep_size=2, **kwargs) - - mock_all2allv_class.assert_called_once_with(**kwargs) - mock_register.assert_called_once_with(mock_instance) - - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' - ) - def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2( - self, mock_register, mock_mc2_class, mock_all2allv_class): - kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2} - mock_all2allv_instance = MagicMock() - mock_mc2_instance = MagicMock() 
- mock_all2allv_class.return_value = mock_all2allv_instance - mock_mc2_class.return_value = mock_mc2_instance - - self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers) - self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers) - - setup_token_dispatchers(ep_size=16, **kwargs) - - mock_all2allv_class.assert_called_once_with(**kwargs) - mock_mc2_class.assert_called_once_with(**kwargs) - self.assertEqual(mock_register.call_count, 2) - mock_register.assert_any_call(mock_all2allv_instance) - mock_register.assert_any_call(mock_mc2_instance) - - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' - ) - def test_setup_token_dispatchers_ep_size_16_skips_if_exist( - self, mock_register, mock_mc2_class, mock_all2allv_class): - kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2} - mock_existing_all2allv = MagicMock() - mock_existing_mc2 = MagicMock() - _Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv - _Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2 - - setup_token_dispatchers(ep_size=16, **kwargs) - - mock_all2allv_class.assert_not_called() - mock_mc2_class.assert_not_called() - mock_register.assert_not_called() - self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"], - mock_existing_all2allv) - self.assertIs(_Dispatchers["TokenDispatcherWithMC2"], - mock_existing_mc2) diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py index 5378b19..d137985 100644 --- a/tests/ut/ops/test_vocab_parallel_embedding.py +++ b/tests/ut/ops/test_vocab_parallel_embedding.py @@ -18,6 +18,7 @@ from unittest.mock import MagicMock, patch import torch +from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.ops.vocab_parallel_embedding import ( 
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding) @@ -31,6 +32,9 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase): self.embedding_dim = 10 self.org_num_embeddings = 40 self.padding_size = 8 + mock_vllm_config = MagicMock() + mock_vllm_config.additional_config = {} + init_ascend_config(mock_vllm_config) def _create_layer(self): # Patch methods and dependencies for VocabParallelEmbedding @@ -206,7 +210,15 @@ class TestAscendLogitsProcessor(unittest.TestCase): return_value=True), patch( "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all", - return_value=torch.randn(1, self.vocab_size)) + return_value=torch.randn(1, self.vocab_size)), + patch( + "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_gather", + return_value=torch.randn(1, self.vocab_size)), + patch( + "vllm_ascend.core.schedule_config.AscendSchedulerConfig.initialize_from_config", + return_value=MagicMock(max_num_batched_tokens=1000, + max_model_len=512, + enable_chunked_prefill=False)) ] for p in self.patches: diff --git a/tests/ut/patch/worker/patch_common/test_patch_linear.py b/tests/ut/patch/worker/patch_common/test_patch_linear.py deleted file mode 100644 index b7fbbc4..0000000 --- a/tests/ut/patch/worker/patch_common/test_patch_linear.py +++ /dev/null @@ -1,167 +0,0 @@ -from importlib import reload - -import pytest -import torch -import vllm -from pytest_mock import MockerFixture - -import vllm_ascend.envs as envs_ascend -from tests.ut.base import PytestBase -from vllm_ascend.patch.worker.patch_common import patch_linear - - -class TestAscendRowParallelLinear(PytestBase): - - def init_row_parallel_linear(self, mocker: MockerFixture): - mocker.patch( - "vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__", - return_value=None, - ) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - return 
patch_linear.AscendRowParallelLinear( - input_size=128, - output_size=256, - ) - - @pytest.mark.parametrize( - "version, expected", - [ - ("1.0.0", 1), - ("2.1.0", 1), - ], - ) - def test_get_hcomm_info(self, version, expected, mocker: MockerFixture): - mock_group = mocker.MagicMock() - backend = mocker.MagicMock() - backend.get_hccl_comm_name = lambda x: x - mock_group._get_backend = lambda x: backend - mock_group.get_hccl_comm_name = lambda x: x - mocker.patch("torch.distributed.get_rank", return_value=1) - mocker.patch( - "torch.distributed.get_global_rank", - return_value=0, - ) - mocker.patch("torch.__version__", new=version) - hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info( - mock_group) - assert hcomm_info == expected - - @pytest.mark.parametrize( - "skip_bias_add, return_bias, bias, expected", - [ - (True, False, torch.tensor(1.0), torch.tensor(14.0)), - (False, True, torch.tensor(1.0), (torch.tensor(14.0), None)), - ( - True, - True, - torch.tensor(1.0), - (torch.tensor(14.0), torch.tensor(1.0)), - ), - ], - ) - def test_forward( - self, - skip_bias_add, - return_bias, - bias, - expected, - mocker: MockerFixture, - ): - mocker_tp_group = mocker.MagicMock() - mocker_tp_group.device_group = mocker.MagicMock() - row_parallel_linear = self.init_row_parallel_linear(mocker) - row_parallel_linear.__dict__["tp_rank"] = 0 - row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add - row_parallel_linear.__dict__["return_bias"] = return_bias - row_parallel_linear.__dict__["bias"] = bias - row_parallel_linear.__dict__["qyuant_method"] = mocker.MagicMock() - row_parallel_linear.__dict__["calc_input"] = lambda x: x # noqa - row_parallel_linear.__dict__[ - "calc_output"] = lambda x: x.matmul( # noqa - torch.tensor([1.0, 2.0])) - ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0])) - if isinstance(ret, tuple): - assert torch.allclose(ret[0], expected[0]) - if ret[1] is None: - assert ret[1] == expected[1] - else: - assert 
torch.allclose(ret[1], expected[1]) - else: - assert torch.allclose(ret, expected) - - @pytest.mark.parametrize( - "input_is_parallel, expected", - [ - (True, torch.tensor([10.0, 2.0])), - (False, torch.tensor([10.0])), - ], - ) - def test_calc_input( - self, - input_is_parallel, - expected, - mocker: MockerFixture, - ): - row_parallel_linear = self.init_row_parallel_linear(mocker) - row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel - input_tensor = torch.Tensor([10, 2]) - mocker.patch( - "vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank", # noqa - return_value=0, - ) - mocker.patch( - "vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim", # noqa - return_value=[torch.Tensor([10]), - torch.Tensor([2])], - ) - input_parallel = row_parallel_linear.calc_input(input_tensor) - assert torch.allclose(input_parallel, expected) - - @pytest.mark.parametrize( - "reduce_results, tp_size, expected", - [ - (True, 2, torch.tensor(56.0)), - (True, 1, torch.tensor(14.0)), - (False, 2, torch.tensor(14.0)), - ], - ) - def test_calc_output( - self, - reduce_results, - tp_size, - expected, - mocker: MockerFixture, - ): - quant_method = mocker.MagicMock() - quant_method.apply = lambda self, x, bias=None: x.matmul( # noqa - torch.tensor([1.0, 2.0])) - row_parallel_linear = self.init_row_parallel_linear(mocker) - row_parallel_linear.__dict__["reduce_results"] = reduce_results - row_parallel_linear.__dict__["tp_size"] = tp_size - row_parallel_linear.__dict__["quant_method"] = quant_method - row_parallel_linear.__dict__["tp_rank"] = 0 - row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None # noqa - - mocker.patch( - "vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group", - return_value=mocker.MagicMock(device_group=mocker.MagicMock()), - ) - mocker.patch( - "torch_npu.npu_mm_all_reduce_base", - side_effect=lambda input_, weight, hccl_info, bias: input_. 
- matmul( # noqa - torch.tensor([4.0, 8.0])), - ) # noqa - ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0])) - assert torch.allclose(ret, expected) - - def test_enable_allreduce_matmul(self, mocker: MockerFixture): - mocker.patch.object(envs_ascend, - "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", - new=True) - reload(patch_linear) - assert envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE - assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id( - patch_linear.AscendRowParallelLinear) diff --git a/tests/ut/quantization/test_func_wrapper.py b/tests/ut/quantization/test_func_wrapper.py deleted file mode 100644 index 5020f80..0000000 --- a/tests/ut/quantization/test_func_wrapper.py +++ /dev/null @@ -1,134 +0,0 @@ -from unittest.mock import patch - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot, - wrapper_rmsnorm_init) - - -class MockRMSNorm: - - def __init__(self, hidden_size: int, **extra_args): - self.hidden_size = hidden_size - self.weight = torch.ones(hidden_size) - self.input_scale = 1.0 - self.input_offset = 0.0 - self.variance_epsilon = 1e-6 - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), - requires_grad=False) - self.ignore_anti = extra_args.get('ignore_anti', True) - - -class TestFuncWrapper(TestBase): - - def test_wrapper_rmsnorm_init(self): - - @wrapper_rmsnorm_init - def init(self, hidden_size: int, **extra_args) -> None: - self.hidden_size = hidden_size - - hidden_size = 128 - extra_args = {'arg1': 'value1'} - - rms_norm = MockRMSNorm(hidden_size, **extra_args) - init(rms_norm, hidden_size, **extra_args) - - self.assertTrue(hasattr(rms_norm, 'ignore_anti')) - self.assertTrue(rms_norm.ignore_anti) - - self.assertTrue(hasattr(rms_norm, 'bias')) - self.assertIsInstance(rms_norm.bias, torch.nn.Parameter) - self.assertEqual(rms_norm.bias.shape, torch.Size([hidden_size])) - self.assertFalse(rms_norm.bias.requires_grad) - - 
@patch('torch_npu._npu_quant_rms_norm') - def test_wrapper_rmsnorm_forward_oot_with_residual( - self, mock_npu_quant_rms_norm): - hidden_size = 128 - x = torch.randn(hidden_size) - residual = torch.randn(hidden_size) - expected_out = torch.randn(hidden_size) - - mock_npu_quant_rms_norm.return_value = (expected_out, residual) - - @wrapper_rmsnorm_forward_oot - def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None): - return x, residual - - rms_norm = MockRMSNorm(hidden_size) - rms_norm.ignore_anti = False - - output, res = forward_oot(rms_norm, x, residual) - - mock_npu_quant_rms_norm.assert_called_once() - - args, kwargs = mock_npu_quant_rms_norm.call_args - self.assertTrue(torch.equal(args[1], rms_norm.weight)) - self.assertTrue(torch.equal(args[2], rms_norm.bias)) - self.assertEqual(args[3], rms_norm.input_scale) - self.assertEqual(args[4], rms_norm.input_offset) - self.assertEqual(args[5], rms_norm.variance_epsilon) - self.assertTrue(torch.equal(res, residual)) - - @patch('torch_npu._npu_quant_rms_norm') - def test_wrapper_rmsnorm_forward_oot_without_residual( - self, mock_npu_quant_rms_norm): - hidden_size = 128 - x = torch.randn(hidden_size) - expected_out = torch.randn(hidden_size) - - mock_npu_quant_rms_norm.return_value = expected_out - - @wrapper_rmsnorm_forward_oot - def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None): - return x - - rms_norm = MockRMSNorm(hidden_size) - rms_norm.ignore_anti = False - - output = forward_oot(rms_norm, x) - - mock_npu_quant_rms_norm.assert_called_once() - - args, kwargs = mock_npu_quant_rms_norm.call_args - self.assertTrue(torch.equal(args[0], x)) - self.assertTrue(torch.equal(args[1], rms_norm.weight)) - self.assertTrue(torch.equal(args[2], rms_norm.bias)) - self.assertEqual(args[3], rms_norm.input_scale) - self.assertEqual(args[4], rms_norm.input_offset) - self.assertEqual(args[5], rms_norm.variance_epsilon) - - self.assertTrue(torch.equal(output, expected_out)) - - def 
test_wrapper_rmsnorm_forward_oot_ignore_anti_with_residual(self): - hidden_size = 128 - x = torch.randn(hidden_size) - residual = torch.randn(hidden_size) - - @wrapper_rmsnorm_forward_oot - def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None): - return x, residual - - rms_norm = MockRMSNorm(hidden_size) - rms_norm.ignore_anti = True - - output, res = forward_oot(rms_norm, x, residual) - - self.assertTrue(torch.equal(output, x.add_(rms_norm.bias))) - self.assertTrue(torch.equal(res, residual)) - - def test_wrapper_rmsnorm_forward_oot_ignore_anti_no_residual(self): - hidden_size = 128 - x = torch.randn(hidden_size) - - @wrapper_rmsnorm_forward_oot - def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None): - return x - - rms_norm = MockRMSNorm(hidden_size) - rms_norm.ignore_anti = True - - output = forward_oot(rms_norm, x) - - self.assertTrue(torch.equal(output, x.add_(rms_norm.bias))) diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_quant_config.py index 7529fea..5a119b4 100644 --- a/tests/ut/quantization/test_quant_config.py +++ b/tests/ut/quantization/test_quant_config.py @@ -73,9 +73,12 @@ class TestAscendQuantConfig(TestBase): self.assertIsNone(result) def test_get_quant_method_for_linear(self): + mock_config = MagicMock() + mock_config.model_config.hf_config.model_type = None linear_layer = MagicMock(spec=LinearBase) # Test skipped layer - with patch.object(self.ascend_config, + with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ + patch.object(self.ascend_config, \ 'is_layer_skipped_ascend', return_value=True): method = self.ascend_config.get_quant_method(linear_layer, ".attn") @@ -83,6 +86,7 @@ class TestAscendQuantConfig(TestBase): # Test quantized layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ + patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ 
patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear: method = self.ascend_config.get_quant_method(linear_layer, ".attn") @@ -93,14 +97,18 @@ class TestAscendQuantConfig(TestBase): def test_get_quant_method_for_attention(self): attention_layer = MagicMock(spec=Attention) - with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', + mock_config = MagicMock() + mock_config.model_config.hf_config.model_type = None + with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ + patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \ return_value=MagicMock()) as mock_ascend_kvcache: # Test with fa_quant_type method = self.ascend_config.get_quant_method( attention_layer, ".attn") self.assertIs(method, mock_ascend_kvcache.return_value) - with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', + with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ + patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \ return_value=MagicMock()) as mock_ascend_kvcache: # Test with kv_quant_type modified_config = {"kv_quant_type": "C8"} @@ -113,9 +121,12 @@ class TestAscendQuantConfig(TestBase): fused_moe_layer = MagicMock(spec=FusedMoE) fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig) fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig) + mock_config = MagicMock() + mock_config.model_config.hf_config.model_type = None # Test skipped layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \ + patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: method = self.ascend_config.get_quant_method( fused_moe_layer, "moe_layer") @@ -123,6 +134,7 @@ class 
TestAscendQuantConfig(TestBase): # Test quantized layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ + patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: method = self.ascend_config.get_quant_method( fused_moe_layer, "moe_layer") @@ -156,33 +168,22 @@ class TestAscendKVCacheMethod(TestBase): def setUp(self): # Setup common test fixtures self.mock_quant_config = MagicMock(spec=AscendQuantConfig) - self.mock_quant_config.quant_description = {"some_config": "value"} - self.prefix = "attention_layer" + self.mock_quant_config.quant_description = {"kv_quant_type": "C8"} + self.prefix = "layer.attn" - # Mock the quantizer and quant_method - self.mock_quantizer = MagicMock() + # Mock quant_method self.mock_quant_method = MagicMock() - - # Patch the AscendQuantizer - self.quantizer_patcher = patch( - 'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer', - return_value=self.mock_quantizer) - self.mock_get_quantizer = self.quantizer_patcher.start() - - self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method + self.patcher = patch( + 'vllm_ascend.quantization.quant_config.get_quant_method') + self.mock_get_quant_method = self.patcher.start() + self.mock_get_quant_method.return_value = self.mock_quant_method # Create instance self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config, self.prefix) def tearDown(self): - self.quantizer_patcher.stop() - - def test_init(self): - """Test initialization with proper quantizer setup.""" - self.mock_get_quantizer.assert_called_once_with( - self.mock_quant_config.quant_description, self.prefix) - self.mock_quantizer.build_attention_method.assert_called_once() + self.patcher.stop() def test_create_weights(self): """Test create_weights delegates to quant_method.""" diff --git 
a/tests/ut/quantization/test_quantizer.py b/tests/ut/quantization/test_quantizer.py deleted file mode 100644 index a51faee..0000000 --- a/tests/ut/quantization/test_quantizer.py +++ /dev/null @@ -1,145 +0,0 @@ -from unittest.mock import MagicMock, patch - -from tests.ut.base import TestBase -from vllm_ascend.quantization.quant_config import AscendQuantConfig -from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer, - W4A8DYNAMICQuantizer, - W8A8Quantizer) - -SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"} - - -class TestGetQuantizer(TestBase): - - def setUp(self): - # Setup common test fixtures - self.supported_types = { - 'INT8': MagicMock(_instance=None), - 'FP16': MagicMock(_instance=None), - 'C8': MagicMock(_instance=None) - } - self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy() - SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types) - self.mock_quant_config = MagicMock(spec=AscendQuantConfig) - self.mock_quant_config.quant_description = {"some_config": "value"} - - def tearDown(self): - # Restore original supported types - SUPPORT_ASCEND_QUANTIZER_TYPE.clear() - SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types) - - def test_get_quantizer_fa(self): - """Test successful quantizer retrieval for different cases.""" - # Setup - quant_description = {'fa_quant_type': 'C8'} - prefix = '.attn' - expected_type = 'C8' - with patch.dict( - 'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', - SUPPORT_ASCEND_QUANTIZER_TYPE): - - result = VLLMAscendQuantizer.get_quantizer( - quant_description, - prefix, - packed_modules_mapping={"some": "mapping"}) - - # Verify - self.assertIsNotNone(result) - self.assertEqual(result, - self.supported_types[expected_type]._instance) - self.supported_types[expected_type].assert_called_once_with( - quant_description) - - def test_get_quantizer_kv(self): - """Test successful quantizer retrieval for different cases.""" - # Setup - quant_description = {'kv_quant_type': 'C8'} - 
prefix = '.attn' - expected_type = 'C8' - with patch.dict( - 'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', - SUPPORT_ASCEND_QUANTIZER_TYPE): - - result = VLLMAscendQuantizer.get_quantizer( - quant_description, - prefix, - packed_modules_mapping={"some": "mapping"}) - - # Verify - self.assertIsNotNone(result) - self.assertEqual(result, - self.supported_types[expected_type]._instance) - self.supported_types[expected_type].assert_called_once_with( - quant_description) - - def test_get_quantizer_linear(self): - """Test successful quantizer retrieval for different cases.""" - # Setup - quant_description = {'linear_type': 'INT8'} - prefix = 'nothing' - expected_type = 'INT8' - with patch('vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type', - return_value=expected_type), \ - patch.dict('vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', SUPPORT_ASCEND_QUANTIZER_TYPE): - - result = VLLMAscendQuantizer.get_quantizer( - quant_description, - prefix, - packed_modules_mapping={"some": "mapping"}) - - # Verify - self.assertIsNotNone(result) - self.assertEqual(result, - self.supported_types[expected_type]._instance) - self.supported_types[expected_type].assert_called_once_with( - quant_description) - - -class TestW8A8Quantizer(TestBase): - - def setUp(self): - self.quantizer = W8A8Quantizer(quant_description={}) - - def test_build_linear_method(self): - with patch('vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod', - return_value=MagicMock()) as mock_linear: - result = self.quantizer.build_linear_method() - mock_linear.assert_called_once_with() - self.assertIsInstance(result, MagicMock) - - def test_build_moe_method(self): - with patch( - 'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod', - return_value=MagicMock()) as mock_linear: - result = self.quantizer.build_moe_method() - mock_linear.assert_called_once_with() - self.assertIsInstance(result, MagicMock) - - def 
test_build_attention_method(self): - with patch('vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod', - return_value=MagicMock()) as mock_linear: - result = self.quantizer.build_attention_method() - mock_linear.assert_called_once_with() - self.assertIsInstance(result, MagicMock) - - -class TestW4A8DYNAMICQuantizer(TestBase): - - def setUp(self): - self.quantizer = W4A8DYNAMICQuantizer(quant_description={}) - - def test_build_linear_method(self): - with patch( - 'vllm_ascend.quantization.quantizer.AscendW4A8DynamicLinearMethod', - return_value=MagicMock()) as mock_linear: - result = self.quantizer.build_linear_method() - mock_linear.assert_called_once_with() - self.assertIsInstance(result, MagicMock) - - def test_build_moe_method(self): - with patch( - 'vllm_ascend.quantization.quantizer.AscendW4A8DynamicFusedMoEMethod', - return_value=MagicMock()) as mock_fused_moe: - result = self.quantizer.build_moe_method() - mock_fused_moe.assert_called_once_with() - self.assertIsInstance(result, MagicMock) diff --git a/tests/ut/quantization/test_utils.py b/tests/ut/quantization/test_utils.py new file mode 100644 index 0000000..153089a --- /dev/null +++ b/tests/ut/quantization/test_utils.py @@ -0,0 +1,62 @@ +import types + +from tests.ut.base import TestBase +from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP, + get_quant_method) + + +class TestGetQuantMethod(TestBase): + + def setUp(self): + self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy( + ) + for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items(): + for layer_type in layer_map.keys(): + ASCEND_QUANTIZATION_METHOD_MAP[quant_type][ + layer_type] = types.new_class(f"{quant_type}_{layer_type}") + + def tearDown(self): + # Restore original map + ASCEND_QUANTIZATION_METHOD_MAP.clear() + ASCEND_QUANTIZATION_METHOD_MAP.update( + self.original_quantization_method_map) + + def test_linear_quant_methods(self): + for quant_type, layer_map in 
ASCEND_QUANTIZATION_METHOD_MAP.items(): + if "linear" in layer_map.keys(): + prefix = "linear_layer" + cls = layer_map["linear"] + method = get_quant_method({"linear_layer.weight": quant_type}, + prefix, "linear") + self.assertIsInstance(method, cls) + + def test_moe_quant_methods(self): + for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items(): + if "moe" in layer_map.keys(): + prefix = "layer" + cls = layer_map["moe"] + method = get_quant_method({"layer.weight": quant_type}, prefix, + "moe") + self.assertIsInstance(method, cls) + + def test_with_fa_quant_type(self): + quant_description = {"fa_quant_type": "C8"} + method = get_quant_method(quant_description, ".attn", "attention") + self.assertIsInstance( + method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"]) + + def test_with_kv_quant_type(self): + quant_description = {"kv_quant_type": "C8"} + method = get_quant_method(quant_description, ".attn", "attention") + self.assertIsInstance( + method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"]) + + def test_invalid_layer_type(self): + quant_description = {"linear_layer.weight": "W8A8"} + with self.assertRaises(NotImplementedError): + get_quant_method(quant_description, "linear_layer", "unsupported") + + def test_invalid_quant_type(self): + quant_description = {"linear_layer.weight": "UNKNOWN"} + with self.assertRaises(NotImplementedError): + get_quant_method(quant_description, "linear_layer", "linear") diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py index d7fdf82..a14702b 100644 --- a/tests/ut/quantization/test_w4a8_dynamic.py +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -1,4 +1,3 @@ -import copy from unittest.mock import Mock, patch import torch @@ -11,8 +10,19 @@ from vllm_ascend.quantization.w4a8_dynamic import ( class TestAscendW4A8DynamicLinearMethod(TestBase): def setUp(self): - self.method = AscendW4A8DynamicLinearMethod() - self.method.group_size = 8 + with patch( + 
'vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config' + ) as mock_get_current_vllm_config: + mock_vllm_config = Mock() + mock_vllm_config.quant_config = Mock( + quant_description={"group_size": 256}) + mock_vllm_config.scheduler_config = Mock( + max_num_batched_tokens=2048, + max_model_len=2048, + enable_chunked_prefill=False) + mock_get_current_vllm_config.return_value = mock_vllm_config + self.method = AscendW4A8DynamicLinearMethod() + self.method.group_size = 8 def test_get_weight(self): weight = self.method.get_weight(8, 32, torch.bfloat16) @@ -37,18 +47,27 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): output_size = 56 group_size = 2 + @patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config') @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config') @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group') @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group') @patch('torch.distributed.get_rank', return_value=0) def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group, - get_current_vllm_config): + get_current_vllm_config, mock_get_ascend_config): + # Mock ascend config + mock_ascend_config = Mock() + mock_ascend_config.dynamic_eplb = False + mock_get_ascend_config.return_value = mock_ascend_config + mock_vllm_config = Mock() mock_vllm_config.quant_config = Mock(quant_description={ "group_size": self.group_size, "version": "0.0.0" }) mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True) + mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048, + max_model_len=2048, + enable_chunked_prefill=False) get_current_vllm_config.return_value = mock_vllm_config self.quant_method = AscendW4A8DynamicFusedMoEMethod() @@ -75,19 +94,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): # old quant version weight param_dict = self.quant_method.get_dynamic_quant_param( self.experts, self.input_size, self.output_size, torch.bfloat16) - 
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32) self.assertEqual(param_dict["w13_weight_scale"].shape, (self.experts, 2 * self.input_size, 1)) self.assertEqual(param_dict["w13_weight_scale_second"].dtype, - torch.bfloat16) + torch.float32) self.assertEqual(param_dict["w13_weight_scale_second"].shape, (self.experts, 2 * self.input_size, self.output_size // self.group_size)) - self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32) self.assertEqual(param_dict["w2_weight_scale"].shape, (self.experts, self.output_size, 1)) self.assertEqual(param_dict["w2_weight_scale_second"].dtype, - torch.bfloat16) + torch.float32) self.assertEqual(param_dict["w2_weight_scale_second"].shape, (self.experts, self.output_size, self.input_size // self.group_size)) @@ -99,40 +118,87 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): self.assertEqual( param_dict["w2_scale_bias"].shape, (self.experts, self.output_size, 16 // self.quant_method.tp_size)) + # per-channel weight + self.quant_method.is_per_channel_weight = True + param_dict = self.quant_method.get_dynamic_quant_param( + self.experts, self.input_size, self.output_size, torch.bfloat16) + pergroup_param = [ + "w13_weight_scale_second", "w13_weight_offset_second", + "w2_weight_scale_second", "w2_weight_offset_second" + ] + is_contains = any(key in param_dict for key in pergroup_param) + self.assertFalse(is_contains) + def build_layer(self, + is_new_quant_version=True, + is_per_channel_weight=False): + layer = torch.nn.Module() + if is_new_quant_version: + layer.w13_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.input_size, self.output_size), + dtype=torch.int8), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.output_size // 2, self.input_size), + dtype=torch.int8), + requires_grad=False) 
+ w13_scale_bias = torch.zeros( + (self.experts, 2 * self.input_size, 1), dtype=torch.float32) + layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, + requires_grad=False) + w2_scale_bias = torch.zeros((self.experts, self.output_size, + 16 // self.quant_method.tp_size), + dtype=torch.float32) + layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias, + requires_grad=False) + else: + layer.w13_weight = torch.nn.Parameter(torch.zeros( + (self.experts, 2 * self.input_size, self.output_size), + dtype=torch.int8), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.output_size, self.input_size), + dtype=torch.int8), + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + (self.experts, 2 * self.input_size, 1), dtype=torch.float32), + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(torch.ones( + (self.experts, self.output_size, 1), dtype=torch.float32), + requires_grad=False) + if not is_per_channel_weight: + layer.w13_weight_scale_second = torch.nn.Parameter( + torch.ones((self.experts, 2 * self.input_size, + self.output_size // self.group_size), + dtype=torch.float32), + requires_grad=False) + layer.w13_weight_offset_second = torch.nn.Parameter( + torch.empty_like(layer.w13_weight_scale_second.data), + requires_grad=False) + layer.w2_weight_scale_second = torch.nn.Parameter( + torch.ones((self.experts, self.output_size, + self.input_size // self.group_size), + dtype=torch.float32), + requires_grad=False) + layer.w2_weight_offset_second = torch.nn.Parameter( + torch.empty_like(layer.w2_weight_scale_second.data), + requires_grad=False) + return layer + + @patch('torch_npu.npu_format_cast') @patch('torch_npu.npu_quantize') @patch('torch.Tensor.npu') - def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize): - # old quant version weight - layer = torch.nn.Module() - layer.w13_weight = torch.nn.Parameter(torch.zeros( - (self.experts, 2 * self.input_size, 
self.output_size), - dtype=torch.int8), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(torch.zeros( - (self.experts, self.output_size, self.input_size), - dtype=torch.int8), - requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(torch.ones( - (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16), - requires_grad=False) - layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones( - (self.experts, 2 * self.input_size, - self.output_size // self.group_size), - dtype=torch.bfloat16), - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(torch.ones( - (self.experts, self.output_size, 1), dtype=torch.bfloat16), - requires_grad=False) - layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones( - (self.experts, self.output_size, - self.input_size // self.group_size), - dtype=torch.bfloat16), - requires_grad=False) - new_layer = copy.deepcopy(layer) - + def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize, + mock_npu_format_cast): mock_npu.return_value = torch.Tensor() mock_npu_quantize.return_value = torch.Tensor() + + def func_by_args(weight, num_format): + return weight + + mock_npu_format_cast.side_effect = func_by_args + # old quant version weight + layer = self.build_layer(is_new_quant_version=False) self.quant_method.process_weights_after_loading(layer) self.assertTrue(hasattr(layer, "w13_scale_bias")) self.assertEqual(layer.w13_scale_bias.data.shape, @@ -144,23 +210,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32) # new quant version weight self.quant_method.new_quant_version = True - new_layer.w13_weight.data = torch.zeros( - (self.experts, self.input_size, self.output_size), - dtype=torch.int8) - new_layer.w2_weight.data = torch.zeros( - (self.experts, self.output_size // 2, self.input_size), - dtype=torch.int8) - w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1), - dtype=torch.float32) - 
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, - requires_grad=False) - w2_scale_bias = torch.zeros( - (self.experts, self.output_size, 16 // self.quant_method.tp_size), - dtype=torch.float32) - new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias, - requires_grad=False) + new_layer = self.build_layer(is_new_quant_version=True) self.quant_method.process_weights_after_loading(new_layer) self.assertEqual(new_layer.w13_scale_bias.data.shape, (self.experts, 2 * self.input_size)) self.assertEqual(new_layer.w2_scale_bias.data.shape, (self.experts, self.output_size)) + self.assertFalse(hasattr(new_layer, "w13_weight_scale_second")) + # per-channel weight + self.quant_method.is_per_channel_weight = True + per_channel_layer = self.build_layer(is_new_quant_version=True, + is_per_channel_weight=True) + self.quant_method.process_weights_after_loading(per_channel_layer) + self.assertEqual(new_layer.w13_scale_bias.data.shape, + (self.experts, 2 * self.input_size)) diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index 90a5f59..3f2557b 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -5,8 +5,8 @@ import torch from tests.ut.base import TestBase from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk, - select_experts) +from vllm_ascend.ops.moe.experts_selector import (_native_grouped_topk, + select_experts) from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, AscendW8A8LinearMethod, @@ -784,7 +784,7 @@ class TestSelectExperts(TestBase): self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.dtype, torch.int32) - @patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk') + @patch('vllm_ascend.ops.moe.experts_selector._native_grouped_topk') def test_grouped_topk_with_correction_bias(self, mock_grouped_topk): """Test grouped 
topk with expert score correction bias""" mock_grouped_topk.return_value = torch.ones(self.num_tokens, diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py new file mode 100644 index 0000000..f25192c --- /dev/null +++ b/tests/ut/quantization/test_w8a8_dynamic.py @@ -0,0 +1,69 @@ +from unittest.mock import Mock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.quantization.w8a8_dynamic import \ + AscendW8A8DynamicFusedMoEMethod + + +class TestAscendW8A8FusedMoEMethod(TestBase): + num_experts = 8 + hidden_size = 128 + intermediate_size = 128 + + @patch("torch.distributed.get_rank") + @patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group") + @patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config") + @patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group") + def setUp(self, mock_get_ep_group, mock_get_ascend_config, + mock_get_mc2_group, mock_get_rank): + with patch( + 'vllm_ascend.quantization.w8a8_dynamic.get_current_vllm_config' + ) as mock_get_current_vllm_config: + mock_vllm_config = Mock() + mock_vllm_config.quant_config = Mock( + quant_description={"group_size": 256}) + mock_vllm_config.scheduler_config = Mock( + max_num_batched_tokens=2048, + max_model_len=2048, + enable_chunked_prefill=False) + mock_get_current_vllm_config.return_value = mock_vllm_config + mock_ep_group = Mock() + mock_get_ep_group.return_value = mock_ep_group + mock_ascend_config = Mock() + + # 创建一个具有具体属性的 Mock 对象来表示 ascend_scheduler_config + mock_ascend_scheduler_config = Mock() + mock_ascend_scheduler_config.enabled = False + mock_ascend_scheduler_config.max_num_batched_tokens = 1024 + mock_ascend_scheduler_config.max_model_len = 2048 + mock_ascend_config.ascend_scheduler_config = mock_ascend_scheduler_config + + mock_ascend_config.torchair_graph_config = Mock(enabled=False) + mock_ascend_config.enable_chunked_prefill = False + mock_get_ascend_config.return_value = mock_ascend_config + 
mock_mc2_group = Mock(device_group=0) + mock_get_mc2_group.return_value = mock_mc2_group + mock_rank = Mock() + mock_get_rank.return_value = mock_rank + + self.quant_method = AscendW8A8DynamicFusedMoEMethod() + + def test_get_weight(self): + param_dict = self.quant_method.get_weight(self.num_experts, + self.intermediate_size, + self.hidden_size, + torch.bfloat16) + self.assertEqual(param_dict["w13_weight"].dtype, torch.int8) + self.assertEqual( + param_dict["w13_weight"].shape, + (self.num_experts, 2 * self.intermediate_size, self.hidden_size)) + + def test_get_dynamic_quant_param(self): + param_dict = self.quant_method.get_dynamic_quant_param( + self.num_experts, self.intermediate_size, self.hidden_size, + torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale"].shape, + (self.num_experts, 2 * self.intermediate_size, 1)) diff --git a/tests/ut/sample/logits_processor/test_builtin.py b/tests/ut/sample/logits_processor/test_builtin.py new file mode 100644 index 0000000..cecd186 --- /dev/null +++ b/tests/ut/sample/logits_processor/test_builtin.py @@ -0,0 +1,40 @@ +import torch +from pytest_mock import MockerFixture +from vllm.config import SchedulerConfig, VllmConfig + +from tests.ut.base import PytestBase +from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor + + +class TestMinPLogitsProcessorInitFunc(PytestBase): + + def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture): + device_cpu = torch.device("cpu") + device_npu = torch.device("npu") + is_pin_memory = False + mock_vllm_config = mocker.MagicMock(spec=VllmConfig) + mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig) + mock_scheduler_config.decode_max_num_seqs = 0 + mock_scheduler_config.max_num_seqs = 128 + mock_vllm_config.scheduler_config = mock_scheduler_config + # torch.zeros/torch.empty returns error on online ut machine, so mock it + mock_tensor = torch.zeros((256, ), + 
dtype=torch.float32, + pin_memory=False) + mocker.patch("torch.zeros", return_value=mock_tensor) + mock_empty_tensor = torch.empty((256, ), dtype=torch.float32) + mocker.patch("torch.empty", return_value=mock_empty_tensor) + + processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu, + is_pin_memory) + + assert processor_cpu.min_p is not None + assert processor_cpu.use_double_tensor is False + assert processor_cpu.min_p_cpu.shape[0] == 256 + + processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu, + is_pin_memory) + + assert processor_cpu.min_p is not None + assert processor_cpu.use_double_tensor is True + assert processor_cpu.min_p_cpu.shape[0] == 256 diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 4c7cfa6..4d3de7f 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -43,6 +43,7 @@ class TestAscendConfig(TestBase): # No additional config given, check the default value here. ascend_config = init_ascend_config(test_vllm_config) self.assertIsNone(ascend_config.expert_map_path) + self.assertFalse(ascend_config.multistream_overlap_shared_expert) torchair_graph_config = ascend_config.torchair_graph_config self.assertFalse(torchair_graph_config.enabled) @@ -51,8 +52,8 @@ class TestAscendConfig(TestBase): self.assertEqual(torchair_graph_config.graph_batch_sizes, []) self.assertFalse(torchair_graph_config.graph_batch_sizes_init) self.assertFalse(torchair_graph_config.enable_multistream_mla) - self.assertFalse(torchair_graph_config.enable_multistream_moe) self.assertTrue(torchair_graph_config.enable_view_optimize) + self.assertTrue(torchair_graph_config.enable_frozen_parameter) self.assertFalse(torchair_graph_config.enable_kv_nz) ascend_scheduler_config = ascend_config.ascend_scheduler_config @@ -68,10 +69,11 @@ class TestAscendConfig(TestBase): "graph_batch_sizes": [1, 2, 4], "graph_batch_sizes_init": False, "enable_multistream_mla": True, - "enable_multistream_moe": True, 
"enable_view_optimize": True, + "enable_frozen_parameter": True, "enable_kv_nz": True }, + "multistream_overlap_shared_expert": True, "ascend_scheduler_config": { "enabled": True }, @@ -80,6 +82,7 @@ class TestAscendConfig(TestBase): } ascend_config = init_ascend_config(test_vllm_config) self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") + self.assertTrue(ascend_config.multistream_overlap_shared_expert) torchair_graph_config = ascend_config.torchair_graph_config self.assertTrue(torchair_graph_config.enabled) @@ -87,8 +90,8 @@ class TestAscendConfig(TestBase): self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4]) self.assertFalse(torchair_graph_config.graph_batch_sizes_init) self.assertTrue(torchair_graph_config.enable_multistream_mla) - self.assertTrue(torchair_graph_config.enable_multistream_moe) self.assertTrue(torchair_graph_config.enable_view_optimize) + self.assertTrue(torchair_graph_config.enable_frozen_parameter) self.assertTrue(torchair_graph_config.enable_kv_nz) ascend_scheduler_config = ascend_config.ascend_scheduler_config @@ -215,21 +218,6 @@ class TestAscendConfig(TestBase): test_vllm_config.model_config = fake_model_config init_ascend_config(test_vllm_config) check_ascend_config(test_vllm_config, False) - # aclgraph + deepseek model - with self.assertRaises(NotImplementedError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - }, - "refresh": True - } - model_path = os.path.join(os.path.dirname(__file__), "fake_weight") - fake_model_config = ModelConfig(model=model_path) - fake_model_config.hf_config = PretrainedConfig() - fake_model_config.hf_config.model_type = "deepseek" - test_vllm_config.model_config = fake_model_config - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) def test_check_torchair_supported(self): test_cases = [('deepseek_v3', True), ('PanguProMoE', True), @@ -318,17 +306,6 @@ class TestAscendConfig(TestBase): } 
init_ascend_config(test_vllm_config) - # enable_multistream_moe should not be enabled without torchair graph mode - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - "enable_multistream_moe": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - # mode should not be configured without torchair graph mode with self.assertRaises(RuntimeError): test_vllm_config.additional_config = { @@ -359,3 +336,27 @@ class TestAscendConfig(TestBase): test_vllm_config.parallel_config = ParallelConfig( data_parallel_size=4, tensor_parallel_size=2) init_ascend_config(test_vllm_config) + + with self.assertRaises(AssertionError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": True, + }, + "oproj_tensor_parallel_size": 2, + "refresh": True + } + test_vllm_config.parallel_config = ParallelConfig( + data_parallel_size=4, tensor_parallel_size=2) + init_ascend_config(test_vllm_config) + + with self.assertRaises(AssertionError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": False, + }, + "oproj_tensor_parallel_size": 2, + "refresh": True + } + test_vllm_config.parallel_config = ParallelConfig( + data_parallel_size=4, tensor_parallel_size=1) + init_ascend_config(test_vllm_config) diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index de8b9be..60f0172 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -36,6 +36,7 @@ class TestNPUPlatform(TestBase): mock_ascend_config = MagicMock() mock_ascend_config.torchair_graph_config.enabled = False mock_ascend_config.ascend_scheduler_config.enabled = False + mock_ascend_config.enable_shared_expert_dp = False return mock_ascend_config def setUp(self): @@ -363,36 +364,6 @@ class TestNPUPlatform(TestBase): CUDAGraphMode.NONE, ) - @patch("vllm_ascend.utils.is_310p", return_value=False) - @patch("vllm_ascend.ascend_config.check_ascend_config") - 
@patch("vllm_ascend.ascend_config.init_ascend_config") - def test_check_and_update_config_disable_aclgraph_when_ray_enabled( - self, mock_init_ascend, mock_check_ascend, mock_is_310p): - mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( - ) - vllm_config = TestNPUPlatform.mock_vllm_config() - vllm_config.model_config.enforce_eager = False - vllm_config.compilation_config.level = CompilationLevel.PIECEWISE - vllm_config.parallel_config.distributed_executor_backend = "ray" - - with self.assertLogs(logger="vllm", level="WARNING") as cm: - from vllm_ascend import platform - - importlib.reload(platform) - self.platform.check_and_update_config(vllm_config) - print(30 * "=", f"cm.output: {cm.output}") - self.assertTrue( - "Ray distributed executor backend is not compatible with ACL Graph mode" - in cm.output[0]) - self.assertEqual( - vllm_config.compilation_config.level, - CompilationLevel.NO_COMPILATION, - ) - self.assertEqual( - vllm_config.compilation_config.cudagraph_mode, - CUDAGraphMode.NONE, - ) - @patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config") @@ -509,6 +480,7 @@ class TestNPUPlatform(TestBase): def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config): mock_config = MagicMock() mock_config.torchair_graph_config.enabled = False + mock_config.enable_shared_expert_dp = False mock_get_ascend_config.return_value = mock_config @@ -589,9 +561,8 @@ class TestNPUPlatform(TestBase): def test_get_punica_wrapper(self): result = self.platform.get_punica_wrapper() - self.assertEqual( - result, - "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU") + self.assertEqual(result, + "vllm_ascend.lora.punica_npu.PunicaWrapperNPU") @patch("torch.npu.reset_peak_memory_stats") @patch("torch.npu.max_memory_allocated") diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 0d264c7..b2b3c32 100644 --- 
a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -24,6 +24,7 @@ from vllm.config import (CompilationConfig, ModelConfig, ParallelConfig, from tests.ut.base import TestBase from vllm_ascend import utils +from vllm_ascend.utils import REGISTERED_ASCEND_OPS class TestUtils(TestBase): @@ -259,8 +260,22 @@ class TestUtils(TestBase): utils.update_aclgraph_sizes(test_vllm_config) del os.environ['HCCL_OP_EXPANSION_MODE'] self.assertEqual( - 147, + 137, len(test_vllm_config.compilation_config.cudagraph_capture_sizes)) + + test_vllm_config.speculative_config = mock.MagicMock() + test_vllm_config.speculative_config.draft_model_config = mock.MagicMock( + ) + test_vllm_config.speculative_config.draft_model_config.hf_config = mock.MagicMock( + ) + test_vllm_config.speculative_config.draft_model_config.hf_config.num_hidden_layers = 2 + os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV' + utils.update_aclgraph_sizes(test_vllm_config) + del os.environ['HCCL_OP_EXPANSION_MODE'] + self.assertEqual( + 111, + len(test_vllm_config.compilation_config.cudagraph_capture_sizes)) + # max_num_batch_sizes >= len(original_sizes) test_compilation_config = CompilationConfig( cudagraph_capture_sizes=[1, 2, 3]) @@ -288,14 +303,14 @@ class TestUtils(TestBase): # ascend custom op is not registered utils.register_ascend_customop() - # should call register_oot three - self.assertEqual(mock_customop.register_oot.call_count, 12) + self.assertEqual(mock_customop.register_oot.call_count, + len(REGISTERED_ASCEND_OPS)) self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED) # ascend custom op is already registered utils.register_ascend_customop() - # should not register_oot again, thus only called three in this ut - self.assertEqual(mock_customop.register_oot.call_count, 12) + self.assertEqual(mock_customop.register_oot.call_count, + len(REGISTERED_ASCEND_OPS)) class TestProfileExecuteDuration(TestBase): diff --git a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py 
b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py index 7aafdfc..7feeba9 100644 --- a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py +++ b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py @@ -165,8 +165,6 @@ class TestTorchairDeepSeekMTP(PytestBase): mocker.patch( "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__", return_value=None) - mocker.patch("vllm.model_executor.layers.sampler.get_sampler", - return_value=None) mocker.patch( "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", return_value=None) diff --git a/tests/ut/torchair/models/test_torchair_deepseek_v2.py b/tests/ut/torchair/models/test_torchair_deepseek_v2.py index e72d023..5a7c2a2 100644 --- a/tests/ut/torchair/models/test_torchair_deepseek_v2.py +++ b/tests/ut/torchair/models/test_torchair_deepseek_v2.py @@ -100,6 +100,11 @@ def mock_distributed(): pp_group.rank_in_group = 0 pp_group.world_size = 1 + mlp_tp_group = Mock(spec=GroupCoordinator) + mlp_tp_group.rank_in_group = 0 + mlp_tp_group.world_size = 1 + mlp_tp_group.all_gather = Mock(return_value=torch.randn(2, 4, 128)) + mock_vllm_config = Mock() mock_vllm_config.scheduler_config = Mock(max_num_seqs=256) mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None) @@ -196,10 +201,6 @@ def test_torchair_deepseek_v2_mlp(mock_distributed, base_config): quant_config=None) assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul) - x = torch.randn(2, 4, 128) - output = mlp(x) - assert output.shape == (2, 4, 128) - with patch( "vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig" ) as mock_quant_config: @@ -274,7 +275,12 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, @patch("torch_npu.npu_add_rms_norm") @patch("torch_npu.npu_rms_norm") -def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm, +@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) 
+@patch("torch.ops.vllm.maybe_chunk_residual", + side_effect=lambda x, residual: residual) +def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual, + mock_maybe_wait_prefetch_done, + mock_rms_norm, mock_add_norm, mock_distributed, base_config, vllm_config): mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py index 19df5dc..fb1cd81 100644 --- a/tests/ut/torchair/ops/test_torchair_fused_moe.py +++ b/tests/ut/torchair/ops/test_torchair_fused_moe.py @@ -24,10 +24,10 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase from vllm_ascend.ascend_forward_context import _get_fused_moe_state from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod -from vllm_ascend.quantization.quantizer import W8A8Quantizer from vllm_ascend.torchair.ops.torchair_fused_moe import ( TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod) -from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402 +from vllm_ascend.utils import adapt_patch # noqa E402 +from vllm_ascend.utils import AscendSocVersion, vllm_version_is adapt_patch(True) @@ -54,6 +54,10 @@ def mock_dp_and_tp_group(mocker): @pytest.fixture def mock_dist_env(mocker: MockerFixture): # init dist env patch + if vllm_version_is("0.10.2"): + dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10]) + else: + dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5]) with patch('torch.distributed.get_rank', return_value=0), \ patch('torch.distributed.get_world_size', return_value=4), \ @@ -67,13 +71,13 @@ def mock_dist_env(mocker: MockerFixture): patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \ patch('vllm_ascend.torchair.ops.torchair_fused_moe.tensor_model_parallel_all_reduce', return_value=torch.randn(5, 32)), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.data_parallel_reduce_scatter', - 
return_value=torch.randn(5, 32)), \ patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config', return_value=MagicMock( - torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False), + torchair_graph_config=MagicMock(enabled=False), + enable_multistream_moe=False, + enable_shared_expert_dp=False, expert_map_path=None )), \ patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map', @@ -81,7 +85,7 @@ def mock_dist_env(mocker: MockerFixture): patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context', return_value=MagicMock( max_tokens_across_dp=10, - dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]) + dp_metadata=dp_metadata, )), \ patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config', return_value=MagicMock( @@ -154,6 +158,8 @@ def default_moe_config(): def moe_method(mock_dist_env): moe = MagicMock() moe.moe_parallel_config.return_value = MagicMock(ep_size=4) + moe.moe_parallel_config.use_ep = False + moe.moe_parallel_config.dp_size = 1 return TorchairAscendUnquantizedFusedMoEMethod(moe) @@ -199,6 +205,9 @@ class MockFusedMoEMethod(FusedMoEMethodBase): expert_weights: torch.Tensor) -> torch.Tensor: pass + def get_fused_moe_quant_config(self, layer: torch.nn.Module): + pass + class TestTorchairAscendFusedMoe: @@ -236,12 +245,9 @@ class TestTorchairAscendFusedMoe: mock_quant_method = MockFusedMoEMethod() mock_quant_config.get_quant_method.return_value = mock_quant_method mock_quant_config.is_layer_skipped_ascend.return_value = False - with patch( - 'vllm_ascend.quantization.quantizer.AscendQuantizer.get_quantizer', - return_value=W8A8Quantizer): + with patch("vllm_ascend.quantization.quant_config.get_quant_method"): moe = TorchairAscendFusedMoE(**default_moe_config, quant_config=mock_quant_config) - assert moe.quant_method is not None assert isinstance(moe.quant_method, 
AscendFusedMoEMethod) diff --git a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py index e7c68f7..4adb598 100644 --- a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py +++ b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py @@ -5,8 +5,9 @@ import torch from tests.ut.base import TestBase from vllm_ascend.torchair.ops.torchair_rotary_embedding import ( - custom_rotary_embedding_enabled, native_rope_deepseek_forward, - rope_forward_oot, rotate_half, yarn_find_correction_dim, yarn_get_mscale) + _set_cos_sin_cache, custom_rotary_embedding_enabled, + native_rope_deepseek_forward, rope_forward_oot, rotate_half, + yarn_find_correction_dim, yarn_get_mscale) class TestCustomRotaryEmbeddingEnabled(TestBase): @@ -103,7 +104,7 @@ class TestRopeForwardOot(TestBase): self.assertTrue(torch.equal(result_q, self.query)) self.assertTrue(torch.equal(result_k, self.key)) - @patch('torch.ops._C') + @patch('torch.ops._C_ascend') @patch( 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') @patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p', @@ -200,6 +201,28 @@ class MockRopeModule: self.sin_cached = None self.rotary_dim = 1 self.base = 1 + self.beta_fast = 32 + self.beta_slow = 1 + self.max_position_embeddings = 4096 + self.mscale = 1.0 + self.scaling_factor = 40 + + def register_buffer(self): + pass + + +class TestSetSinCosCache(TestBase): + + def test_set_cos_sin_cache(self): + module = MockRopeModule() + + with patch.object(module, "register_buffer") as mock_register_buffer: + _set_cos_sin_cache(module, + 1024, + device="cpu", + dtype=torch.bfloat16) + + mock_register_buffer.assert_called() class TestNativeRopeDeepseekForward(TestBase): @@ -220,30 +243,6 @@ class TestNativeRopeDeepseekForward(TestBase): assert q_pe.shape == query.shape assert k_pe.shape == key.shape - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding._set_cos_sin_cache' - ) - @patch( - 
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') - def test_native_rope_deepseek_forward_cache_handling( - self, mock_rope_forward_oot, mock_set_cache): - # Test cache situation is true - module = MockRopeModule(max_seq_len=1024) - positions = torch.tensor([1, 2, 3]) - query = torch.randn(1, 8, 128) - key = torch.randn(1, 8, 128) - - mock_rope_forward_oot.return_value = (query, key) - - q_pe, k_pe = native_rope_deepseek_forward(module, - positions, - query, - key, - max_seq_len=2048) - - assert q_pe.shape == query.shape - assert k_pe.shape == key.shape - @patch( 'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') def test_native_rope_deepseek_forward_key_reshaping( diff --git a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py index cd94101..9fd3f29 100644 --- a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py +++ b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py @@ -1,4 +1,3 @@ -import copy from unittest.mock import Mock, patch import torch @@ -85,19 +84,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): # old quant version weight param_dict = self.quant_method.get_dynamic_quant_param( self.experts, self.input_size, self.output_size, torch.bfloat16) - self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32) self.assertEqual(param_dict["w13_weight_scale"].shape, (self.experts, 2 * self.input_size, 1)) self.assertEqual(param_dict["w13_weight_scale_second"].dtype, - torch.bfloat16) + torch.float32) self.assertEqual(param_dict["w13_weight_scale_second"].shape, (self.experts, 2 * self.input_size, self.output_size // self.group_size)) - self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32) self.assertEqual(param_dict["w2_weight_scale"].shape, 
(self.experts, self.output_size, 1)) self.assertEqual(param_dict["w2_weight_scale_second"].dtype, - torch.bfloat16) + torch.float32) self.assertEqual(param_dict["w2_weight_scale_second"].shape, (self.experts, self.output_size, self.input_size // self.group_size)) @@ -109,40 +108,80 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): self.assertEqual( param_dict["w2_scale_bias"].shape, (self.experts, self.output_size, 16 // self.quant_method.tp_size)) + # per-channel weight + self.quant_method.is_per_channel_weight = True + param_dict = self.quant_method.get_dynamic_quant_param( + self.experts, self.input_size, self.output_size, torch.bfloat16) + pergroup_param = [ + "w13_weight_scale_second", "w13_weight_offset_second", + "w2_weight_scale_second", "w2_weight_offset_second" + ] + is_contains = any(key in param_dict for key in pergroup_param) + self.assertFalse(is_contains) + + def build_layer(self, + is_new_quant_version=True, + is_per_channel_weight=False): + layer = torch.nn.Module() + if is_new_quant_version: + layer.w13_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.input_size, self.output_size), + dtype=torch.int8), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.output_size // 2, self.input_size), + dtype=torch.int8), + requires_grad=False) + w13_scale_bias = torch.zeros( + (self.experts, 2 * self.input_size, 1), dtype=torch.float32) + layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, + requires_grad=False) + w2_scale_bias = torch.zeros((self.experts, self.output_size, + 16 // self.quant_method.tp_size), + dtype=torch.float32) + layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias, + requires_grad=False) + else: + layer.w13_weight = torch.nn.Parameter(torch.zeros( + (self.experts, 2 * self.input_size, self.output_size), + dtype=torch.int8), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.output_size, self.input_size), + 
dtype=torch.int8), + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + (self.experts, 2 * self.input_size, 1), dtype=torch.float32), + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(torch.ones( + (self.experts, self.output_size, 1), dtype=torch.float32), + requires_grad=False) + if not is_per_channel_weight: + layer.w13_weight_scale_second = torch.nn.Parameter( + torch.ones((self.experts, 2 * self.input_size, + self.output_size // self.group_size), + dtype=torch.float32), + requires_grad=False) + layer.w13_weight_offset_second = torch.nn.Parameter( + torch.empty_like(layer.w13_weight_scale_second.data), + requires_grad=False) + layer.w2_weight_scale_second = torch.nn.Parameter( + torch.ones((self.experts, self.output_size, + self.input_size // self.group_size), + dtype=torch.float32), + requires_grad=False) + layer.w2_weight_offset_second = torch.nn.Parameter( + torch.empty_like(layer.w2_weight_scale_second.data), + requires_grad=False) + return layer @patch('torch_npu.npu_quantize') @patch('torch.Tensor.npu') def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize): - # old quant version weight - layer = torch.nn.Module() - layer.w13_weight = torch.nn.Parameter(torch.zeros( - (self.experts, 2 * self.input_size, self.output_size), - dtype=torch.int8), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(torch.zeros( - (self.experts, self.output_size, self.input_size), - dtype=torch.int8), - requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(torch.ones( - (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16), - requires_grad=False) - layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones( - (self.experts, 2 * self.input_size, - self.output_size // self.group_size), - dtype=torch.bfloat16), - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(torch.ones( - (self.experts, self.output_size, 1), dtype=torch.bfloat16), - requires_grad=False) - 
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones( - (self.experts, self.output_size, - self.input_size // self.group_size), - dtype=torch.bfloat16), - requires_grad=False) - new_layer = copy.deepcopy(layer) - mock_npu.return_value = torch.Tensor() mock_npu_quantize.return_value = torch.Tensor() + # old quant version weight + layer = self.build_layer(is_new_quant_version=False) self.quant_method.process_weights_after_loading(layer) self.assertTrue(hasattr(layer, "w13_scale_bias")) self.assertEqual(layer.w13_scale_bias.data.shape, @@ -154,23 +193,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32) # new quant version weight self.quant_method.new_quant_version = True - new_layer.w13_weight.data = torch.zeros( - (self.experts, self.input_size, self.output_size), - dtype=torch.int8) - new_layer.w2_weight.data = torch.zeros( - (self.experts, self.output_size // 2, self.input_size), - dtype=torch.int8) - w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1), - dtype=torch.float32) - new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, - requires_grad=False) - w2_scale_bias = torch.zeros( - (self.experts, self.output_size, 16 // self.quant_method.tp_size), - dtype=torch.float32) - new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias, - requires_grad=False) + new_layer = self.build_layer(is_new_quant_version=True) self.quant_method.process_weights_after_loading(new_layer) self.assertEqual(new_layer.w13_scale_bias.data.shape, (self.experts, 2 * self.input_size)) self.assertEqual(new_layer.w2_scale_bias.data.shape, (self.experts, self.output_size)) + self.assertFalse(hasattr(new_layer, "w13_weight_scale_second")) + # per-channel weight + self.quant_method.is_per_channel_weight = True + per_channel_layer = self.build_layer(is_new_quant_version=True, + is_per_channel_weight=True) + self.quant_method.process_weights_after_loading(per_channel_layer) + 
self.assertEqual(new_layer.w13_scale_bias.data.shape, + (self.experts, 2 * self.input_size)) diff --git a/tests/ut/torchair/test_torchair_attention.py b/tests/ut/torchair/test_torchair_attention.py new file mode 100644 index 0000000..dd262dc --- /dev/null +++ b/tests/ut/torchair/test_torchair_attention.py @@ -0,0 +1,95 @@ +from unittest.mock import MagicMock, patch + +import torch +from vllm.attention.backends.abstract import AttentionType +from vllm.distributed.parallel_state import GroupCoordinator + +from tests.ut.base import TestBase +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.torchair.torchair_attention import \ + AscendAttentionTorchairBackendImpl + + +class TestAscendAttentionTorchairBackendImpl(TestBase): + + @patch("torch.zeros") + @patch('vllm.distributed.parallel_state._TP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) # TODO + @patch("vllm.distributed.get_tensor_model_parallel_world_size", + return_value=2) # TODO + @patch("vllm.config.get_current_vllm_config") # TODO + @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") # TODO + def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp, + mock_zeros): + mock_tp.world_size = 2 # TODO + ascend_config.torchair_graph_config.enabled = True # TODO + ascend_config.torchair_graph_config.enable_kv_nz = False # TODO + speculative_config = MagicMock() + speculative_config.num_speculative_tokens = 4 + vllm_config.speculative_config = speculative_config + + num_heads = 32 + head_size = 128 # TODO + scale = 0.1 # TODO + num_kv_heads = 4 + kv_cache_dtype = "auto" + attn_type = AttentionType.DECODER + mock_zeros.return_value = torch.ones((), + device='cpu', + dtype=torch.int32) + + self.impl = AscendAttentionTorchairBackendImpl( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=kv_cache_dtype, + blocksparse_params=None, + logits_soft_cap=None, 
+ attn_type=attn_type, + kv_sharing_target_layer_name=None) + + @patch("torch_npu.npu_scatter_nd_update_") + @patch("torch_npu.npu_fused_infer_attention_score") + def test_forward_with_decode_only(self, mock_fused, _): + layer = MagicMock() + layer._k_scale_float = 1.0 + layer._v_scale_float = 1.0 + + seq_len = 1 + num_tokens = 100 + num_blocks = 256 + block_size = 4 + + query = torch.randn(num_tokens, seq_len, + self.impl.num_heads * self.impl.head_size) + key = torch.randn(num_tokens, seq_len, + self.impl.num_kv_heads * self.impl.head_size) + value = torch.randn(num_tokens, seq_len, + self.impl.num_kv_heads * self.impl.head_size) + kv_cache = (torch.randn(num_blocks, block_size, + self.impl.num_heads * self.impl.head_size), + torch.randn(num_blocks, block_size, + self.impl.num_heads * self.impl.head_size)) + output = torch.randn(num_tokens, self.impl.num_heads, + self.impl.head_size) + + decode = MagicMock() # TODO + decode.seq_lens_list = [2] * num_tokens + decode.block_table = torch.ones(num_tokens, 8, dtype=torch.int32) + decode.attn_mask = None + + metadata = MagicMock() + metadata.attn_state = AscendAttentionState.DecodeOnly + metadata.slot_mapping = torch.arange(num_tokens, dtype=torch.int32) + metadata.decode = decode + + mock_fused.return_value = (torch.ones(num_tokens, self.impl.num_heads, + self.impl.head_size), + torch.ones(1)) + + result = self.impl.forward(layer, query, key, value, kv_cache, + metadata, output, False) + self.assertEqual(result.shape[0], num_tokens) diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py index 6ee983a..ec8ddfd 100644 --- a/tests/ut/torchair/test_torchair_mla.py +++ b/tests/ut/torchair/test_torchair_mla.py @@ -190,12 +190,15 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + ascend_config = MagicMock() ascend_config.torchair_graph_config = 
MagicMock() ascend_config.torchair_graph_config.enabled = True with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config", return_value=ascend_config): - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) self.assertEqual(builder.block_size, @@ -216,7 +219,10 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): ascend_config.torchair_graph_config = MagicMock() ascend_config.torchair_graph_config.enabled = True - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + mock_vllm_config.speculative_config = None + + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) input_batch = MagicMock() @@ -250,9 +256,12 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config", return_value=ascend_config): - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) input_batch = MagicMock() @@ -285,7 +294,10 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + mock_vllm_config.speculative_config = None + + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) @@ -305,7 +317,10 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + mock_vllm_config.speculative_config = None + + builder = AscendMLATorchairMetadataBuilder(None, None, + 
mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) @@ -326,7 +341,10 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + mock_vllm_config.speculative_config = None + + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) @@ -351,7 +369,11 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.model_config.dtype = torch.float16 mock_device = 'cpu' + mock_vllm_config.speculative_config = None + builder = AscendMLATorchairMetadataBuilder( + None, + None, mock_vllm_config, mock_device, metadata_cls=AscendMLATorchairMetadata) @@ -416,7 +438,11 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): model = MagicMock(spec=nn.Module) model.model = MagicMock(spec=nn.Module) + mock_vllm_config.speculative_config = None + builder = AscendMLATorchairMetadataBuilder( + None, + None, mock_vllm_config, mock_device, metadata_cls=AscendMLATorchairMetadata) @@ -437,14 +463,16 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): max_query_len=1, decode_token_per_req=torch.tensor([1, 1, 1]), block_table_tensor=torch.zeros((10, 10)), - slot_mapping_cpu=torch.tensor(range(20)), + slot_mapping=torch.tensor(range(20)), actual_seq_lengths_q=torch.tensor([0, 1, 2]), positions=torch.tensor([1, 1]), attn_mask=torch.ones((15, 15)), spec_attn_mask=None, - attn_state=AscendAttentionState.ChunkedPrefill) + attn_state=AscendAttentionState.ChunkedPrefill, + num_computed_tokens_cpu=None, + seq_lens=None) - metadata = builder.build(common_attn_metadata, model) + metadata = builder.build(1, common_attn_metadata, model) self.assertIsInstance(metadata, AscendMLATorchairMetadata) self.assertEqual(metadata.num_input_tokens, 0) diff --git a/tests/ut/torchair/test_utils.py 
b/tests/ut/torchair/test_utils.py index fb526b5..edd3fc2 100644 --- a/tests/ut/torchair/test_utils.py +++ b/tests/ut/torchair/test_utils.py @@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch import torch from tests.ut.base import TestBase -from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE from vllm_ascend.torchair import utils @@ -135,15 +134,3 @@ class TestTorchairUtils(TestBase): utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ) mock_npu_cast.assert_not_called() - - def test_torchair_quant_method_register(self): - - TorchairW8A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[ - "W8A8_DYNAMIC"] - TorchairW4A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[ - "W4A8_DYNAMIC"] - utils.torchair_quant_method_register() - self.assertNotEqual(TorchairW8A8DYNAMICQuantizer, - SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"]) - self.assertNotEqual(TorchairW4A8DYNAMICQuantizer, - SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"]) diff --git a/tests/ut/worker/test_input_batch.py b/tests/ut/worker/test_input_batch.py index a72dbdc..703098d 100644 --- a/tests/ut/worker/test_input_batch.py +++ b/tests/ut/worker/test_input_batch.py @@ -24,8 +24,8 @@ from vllm.utils import make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable +from vllm_ascend.worker.block_table import BlockTable, MultiGroupBlockTable from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch VOCAB_SIZE = 1024 diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py new file mode 100644 index 0000000..70b7c7d --- /dev/null +++ b/tests/ut/worker/test_model_runner_v1.py @@ -0,0 +1,107 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +from unittest.mock import MagicMock, patch + +import pytest + +from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.utils import AscendSocVersion +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner + + +# yapf: disable +@pytest.mark.parametrize( + "soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method", + [ + # Case 1: Expert parallel is disabled, should always be 'allgather' + (AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER), + (AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER), + + # Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2 + (AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL), + (AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL), + (AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition + + # Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather + (AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER), + (AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER), + + # Case 4: A3 SOC + (AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2), + (AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLTOALL), + ]) +# yapf: enable +def test_select_moe_comm_method(soc_version, enable_expert_parallel, + world_size, num_tokens, mc2_tokens_capacity, + quant_type, expected_method): + """ + Tests the 
_select_moe_comm_method with various configurations including quant_type. + """ + # Mock the NPUModelRunner instance and its dependencies + mock_runner = MagicMock(spec=NPUModelRunner) + mock_runner.parallel_config = MagicMock() + mock_runner.parallel_config.enable_expert_parallel = enable_expert_parallel + mock_runner.parallel_config.world_size_across_dp = world_size + mock_runner.mc2_tokens_capacity = mc2_tokens_capacity + + # Add vllm_config.model_config.hf_config mock with moe_quantize + mock_hf_config = MagicMock() + mock_hf_config.moe_quantize = quant_type + mock_model_config = MagicMock() + mock_model_config.hf_config = mock_hf_config + mock_vllm_config = MagicMock() + mock_vllm_config.model_config = mock_model_config + mock_runner.vllm_config = mock_vllm_config + + # Patch the helper functions + with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version', + return_value=soc_version), \ + patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank', + return_value=True): + + # Bind the real method to the mock object + method = NPUModelRunner._select_moe_comm_method( + mock_runner, num_tokens, False) + + # Assert the result + assert method == expected_method + + +def test_select_moe_comm_method_unsupported_soc(): + """ + Tests that _select_moe_comm_method raises ValueError for an unsupported SOC. 
+ """ + mock_runner = MagicMock(spec=NPUModelRunner) + mock_runner.parallel_config = MagicMock() + mock_runner.parallel_config.enable_expert_parallel = True + mock_runner.mc2_tokens_capacity = 256 + + # Add vllm_config.model_config.hf_config mock with moe_quantize + mock_hf_config = MagicMock() + mock_hf_config.moe_quantize = None + mock_model_config = MagicMock() + mock_model_config.hf_config = mock_hf_config + mock_vllm_config = MagicMock() + mock_vllm_config.model_config = mock_model_config + mock_runner.vllm_config = mock_vllm_config + + unsupported_soc = "UnsupportedSOC" + + with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version', + return_value=unsupported_soc), \ + patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank', + return_value=True), \ + pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"): + + NPUModelRunner._select_moe_comm_method(mock_runner, 100, False) diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index af3d904..eb05a7a 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -258,7 +258,7 @@ class TestNPUWorker(TestBase): # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): worker = NPUWorker() - + worker._sleep_saved_buffers = {} # Test wake_up method worker.wake_up(tags=["test_tag"]) @@ -355,6 +355,28 @@ class TestNPUWorker(TestBase): self.assertIn("Profiler is not enabled", str(cm.exception)) + @patch("vllm_ascend.worker.worker_v1.envs_vllm") + @patch("vllm_ascend.worker.worker_v1.envs_ascend") + def test_profile_and_msmonitor_both_enabled_raises_error( + self, mock_envs_vllm, mock_envs_ascend): + """Test profile method raises exception when both profiler and msmonitor are enabled""" + from vllm_ascend.worker.worker_v1 import NPUWorker + + mock_envs_vllm.VLLM_TORCH_PROFILER_DIR = "/path/to/traces" + mock_envs_ascend.MSMONITOR_USE_DAEMON = 1 + + # Create worker mock + with 
patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): + worker = NPUWorker() + + # Test should raise exception + with self.assertRaises(RuntimeError) as cm: + _ = worker._init_profiler() + + self.assertIn( + "MSMONITOR_USE_DAEMON and VLLM_TORCH_PROFILER_DIR cannot be both set at the same time.", + str(cm.exception)) + def test_lora_methods(self): """Test LoRA related methods""" from vllm_ascend.worker.worker_v1 import NPUWorker @@ -828,6 +850,7 @@ class TestNPUWorker(TestBase): # Mock scheduler_output and return result mock_scheduler_output = MagicMock() + mock_scheduler_output.total_num_scheduled_tokens = 1 # Create a real ModelRunnerOutput instance or mock mock_model_output = MagicMock(spec=ModelRunnerOutput) worker.model_runner.execute_model.return_value = mock_model_output @@ -842,9 +865,8 @@ class TestNPUWorker(TestBase): @patch("vllm_ascend.worker.worker_v1.get_pp_group") @patch("vllm_ascend.worker.worker_v1.get_tp_group") - @patch("vllm_ascend.worker.worker_v1.has_kv_transfer_group") - def test_execute_model_middle_rank(self, mock_has_kv_transfer_group, - mock_get_tp_group, mock_get_pp_group): + def test_execute_model_middle_rank(self, mock_get_tp_group, + mock_get_pp_group): """Test execute_model method - middle rank case""" from vllm.sequence import IntermediateTensors @@ -875,10 +897,8 @@ class TestNPUWorker(TestBase): ) worker.model_runner.execute_model.return_value = mock_intermediate_output - # Set has_kv_transfer_group returns False - mock_has_kv_transfer_group.return_value = False - mock_scheduler_output = MagicMock() + mock_scheduler_output.total_num_scheduled_tokens = 1 # Test execute_model result = worker.execute_model(mock_scheduler_output) @@ -926,6 +946,7 @@ class TestNPUWorker(TestBase): # Mock return result mock_scheduler_output = MagicMock() + mock_scheduler_output.total_num_scheduled_tokens = 1 mock_model_output = MagicMock(spec=ModelRunnerOutput) worker.model_runner.execute_model.return_value = mock_model_output @@ -1009,7 +1030,9 
@@ class TestNPUWorker(TestBase): @patch("vllm_ascend.worker.worker_v1.NPUPlatform.seed_everything") @patch("vllm_ascend.worker.worker_v1.logger") - def test_compile_or_warm_up_model_with_eager_mode(self, mock_logger, + @patch("vllm_ascend.worker.worker_v1.NPUWorker._warm_up_atb") + def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb, + mock_logger, mock_seed_everything): """Test compile_or_warm_up_model method - eager mode""" from vllm_ascend.worker.worker_v1 import NPUWorker @@ -1051,10 +1074,14 @@ class TestNPUWorker(TestBase): # Verify seed setting mock_seed_everything.assert_called_once_with(12345) + # Verify atb warm up + mock_warm_up_atb.assert_called_once() + @patch("vllm_ascend.worker.worker_v1.NPUPlatform.seed_everything") @patch("vllm_ascend.worker.worker_v1.logger") + @patch("vllm_ascend.worker.worker_v1.NPUWorker._warm_up_atb") def test_compile_or_warm_up_model_with_graph_capture( - self, mock_logger, mock_seed_everything): + self, mock_warm_up_atb, mock_logger, mock_seed_everything): """Test compile_or_warm_up_model method - with graph capture enabled""" from vllm_ascend.worker.worker_v1 import NPUWorker @@ -1087,6 +1114,9 @@ class TestNPUWorker(TestBase): # Verify seed setting mock_seed_everything.assert_called_once_with(67890) + # Verify atb warm up + mock_warm_up_atb.assert_called_once() + @patch("vllm_ascend.worker.worker_v1.CaMemAllocator") def test_initialize_from_config_with_sleep_mode(self, mock_allocator_class): @@ -1141,3 +1171,55 @@ class TestNPUWorker(TestBase): # Verify calls worker.model_runner.initialize_kv_cache.assert_called_once_with( mock_kv_cache_config) + + @patch("vllm_ascend.worker.worker_v1.get_pp_group") + @patch("vllm_ascend.worker.worker_v1.get_tp_group") + @patch("vllm_ascend.worker.worker_v1.EMPTY_MODEL_RUNNER_OUTPUT") + def test_execute_model_kv_connector_not_finished(self, mock_empty_output, + mock_get_tp_group, + mock_get_pp_group): + """Test execute_model method - kv_connector_output not finished 
sending/recving case""" + from vllm.sequence import IntermediateTensors + + from vllm_ascend.worker.worker_v1 import NPUWorker + + # Create worker mock + with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): + worker = NPUWorker() + worker.model_runner = MagicMock() + worker.vllm_config = MagicMock() + worker.vllm_config.parallel_config = MagicMock() + worker.vllm_config.parallel_config.distributed_executor_backend = "ray" + + # Set as middle rank (not first, not last) + mock_pp_group = MagicMock() + mock_pp_group.is_first_rank = False + mock_pp_group.is_last_rank = False + mock_get_pp_group.return_value = mock_pp_group + + # Setup tensor reception data + mock_pp_group.recv_tensor_dict.return_value = {"tensor": "data"} + + # Create mock kv_connector_output - both finished_sending and finished_recving are False + mock_kv_connector_output = MagicMock() + mock_kv_connector_output.finished_sending = False + mock_kv_connector_output.finished_recving = False + + # Mock return IntermediateTensors with kv_connector_output + mock_intermediate_output = MagicMock(spec=IntermediateTensors) + mock_intermediate_output.tensors = {"output_tensor": "data"} + mock_intermediate_output.kv_connector_output = mock_kv_connector_output + worker.model_runner.execute_model.return_value = mock_intermediate_output + + mock_scheduler_output = MagicMock() + mock_scheduler_output.total_num_scheduled_tokens = 1 + + # Test execute_model + result = worker.execute_model(mock_scheduler_output) + + # Verify tensor reception and sending + mock_pp_group.recv_tensor_dict.assert_called_once() + mock_pp_group.send_tensor_dict.assert_called_once() + + # When both finished_sending and finished_recving are False, should return EMPTY_MODEL_RUNNER_OUTPUT directly + self.assertEqual(result, mock_empty_output) diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index 7588e70..90aede7 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -23,5 +23,7 @@ def register(): def 
register_model(): + import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa + from .models import register_model register_model() diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index e46cd9a..65ea3ea 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -34,6 +34,8 @@ class AscendConfig: def __init__(self, vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} + self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32" + self.use_sfa = self.is_deepseek_sfa torchair_graph_config = additional_config.get("torchair_graph_config", {}) @@ -43,13 +45,26 @@ class AscendConfig: "ascend_scheduler_config", {}) self.ascend_scheduler_config = AscendSchedulerConfig( ascend_scheduler_config) - + # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. 
Remove this config self.expert_map_path = additional_config.get("expert_map_path", None) + self.expert_map_record_path = additional_config.get( + "expert_map_record_path", + None) # Provide path to export expert map + self.init_redundancy_expert = additional_config.get( + "init_redundancy_expert", 0) + self.dynamic_eplb = additional_config.get("dynamic_eplb", False) + self.num_iterations_eplb_update = additional_config.get( + "num_iterations_eplb_update", 400) + self.gate_eplb = additional_config.get("gate_eplb", False) + self.num_wait_worker_iterations = additional_config.get( + "num_wait_worker_iterations", 30) self.chunked_prefill_for_mla = additional_config.get( "chunked_prefill_for_mla", False) self.enable_shared_expert_dp = additional_config.get( "enable_shared_expert_dp", False ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel + self.multistream_overlap_shared_expert = additional_config.get( + "multistream_overlap_shared_expert", False) self.enable_prefetch = additional_config.get("enable_prefetch", False) self.lmhead_tensor_parallel_size = additional_config.get( "lmhead_tensor_parallel_size", None) @@ -61,6 +76,24 @@ class AscendConfig: raise AssertionError( "lmhead_tensor_parallel_size is only supported in the pure DP scenario" ) + self.oproj_tensor_parallel_size = additional_config.get( + "oproj_tensor_parallel_size", None) + if self.oproj_tensor_parallel_size is not None: + logger.info( + f"Enable oproj_tensor_parallel_size={self.oproj_tensor_parallel_size} in pure DP scenario" + ) + if vllm_config.parallel_config.tensor_parallel_size != 1: + raise AssertionError( + "oproj_tensor_parallel_size is only supported in the pure DP scenario" + ) + if not self.torchair_graph_config.enabled: + raise AssertionError( + "oproj_tensor_parallel_size is only supported in graph mode" + ) + if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer: + raise AssertionError( + 
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node." + ) class TorchairGraphConfig: @@ -81,10 +114,10 @@ class TorchairGraphConfig: "graph_batch_sizes_init", False) self.enable_multistream_mla = torchair_graph_config.get( "enable_multistream_mla", False) - self.enable_multistream_moe = torchair_graph_config.get( - "enable_multistream_moe", False) self.enable_view_optimize = torchair_graph_config.get( "enable_view_optimize", True) + self.enable_frozen_parameter = torchair_graph_config.get( + "enable_frozen_parameter", True) self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False) if not isinstance(self.graph_batch_sizes, list): @@ -117,10 +150,6 @@ class TorchairGraphConfig: raise RuntimeError( "enable_multistream_mla is valid only when Torchair graph mode is enabled" ) - if self.enable_multistream_moe: - raise RuntimeError( - "enable_multistream_moe is valid only when Torchair graph mode is enabled" - ) if self.enable_kv_nz: raise RuntimeError( "enable_kv_nz is valid only when Torchair graph mode is enabled" @@ -200,14 +229,8 @@ def check_ascend_config(vllm_config, enforce_eager): "it has been disabled automatically.") # aclgraph case else: - # aclgraph doesn't work with deepseek model and only qwen model is well tested. if vllm_config.model_config: model_type = vllm_config.model_config.hf_config.model_type - if "deepseek" in model_type: - raise NotImplementedError( - "ACL Graph does not support deepseek. Please " - "try torchair graph mode to serve deepseek models on vllm-ascend." - " Or set `enforce_eager=True` to use eager mode.") if "qwen" not in model_type: logger.warning( "ACL Graph is currently experimental. 
Please " diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 601f33a..607f029 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -11,6 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context, set_forward_context) import vllm_ascend.envs as envs_ascend +from vllm_ascend.utils import enable_sp class FusedMoEState(Enum): @@ -22,6 +23,13 @@ class FusedMoEState(Enum): All2AllSeq = 5 +class MoECommType(Enum): + ALLGATHER = 0 + MC2 = 1 + ALLTOALL = 2 + NAIVE_MULTICAST = 3 + + # TODO(zzzzwwjj): add soc_version to choose branch def _get_fused_moe_state(ep_size: int, with_prefill: bool, is_deepseek_v3_r1: bool): @@ -42,18 +50,6 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool, return FusedMoEState.MC2 -def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str: - if ep_size == 1: - return "TokenDispatcherWithAllGather" - - if ep_size < 16: - return "TokenDispatcherWithAll2AllV" - - if with_prefill: - return "TokenDispatcherWithAll2AllV" - return "TokenDispatcherWithMC2" - - @contextmanager def set_ascend_forward_context( attn_metadata: Any, @@ -64,10 +60,12 @@ def set_ascend_forward_context( with_prefill: bool = True, in_profile_run: bool = False, reserved_mc2_mask: Optional[torch.Tensor] = None, - moe_comm_method: str = "", + moe_comm_type: Optional[MoECommType] = None, num_actual_tokens: Optional[int] = None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None): + batch_descriptor: Optional[BatchDescriptor] = None, + prefetch_stream: torch.npu.Stream = None, + model_instance: torch.nn.Module = None): """A context manager that stores the current forward context, can be attention metadata, etc. We add some additional param into forward_context. 
@@ -82,8 +80,13 @@ def set_ascend_forward_context( batch_descriptor=batch_descriptor, ): forward_context = get_forward_context() - forward_context.moe_comm_method_name = moe_comm_method + "commimpl" + + from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method + forward_context.moe_comm_type = moe_comm_type + forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type) + forward_context.with_prefill = with_prefill + tp_world_size = get_tensor_model_parallel_world_size() ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) @@ -95,16 +98,63 @@ def set_ascend_forward_context( forward_context.fused_moe_state = fused_moe_state forward_context.in_profile_run = in_profile_run - from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ - get_token_dispatcher - dispatcher_name = get_dispatcher_name(ep_size, with_prefill) - dispatcher = get_token_dispatcher(dispatcher_name) - forward_context.token_dispatcher = dispatcher - # NOTE: This cannot be set using set_forward_context # due to multiple warmups before actual capturing forward_context.capturing = False + # set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature. + # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, + # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, + # the performance may degrade due to the switching of communication methods. 
+ sp_enabled = enable_sp(vllm_config) and \ + tp_world_size > 1 and \ + num_tokens is not None and num_tokens > 1000 + + if sp_enabled: + pad_size = (tp_world_size - + (num_tokens % tp_world_size)) % tp_world_size + forward_context.pad_size = pad_size + forward_context.sp_enabled = sp_enabled + + # set this for rope forward_oot using + forward_context.is_first_layer = True + + # set layer_idx to enable optimization features that depend on this information. + # This is only applicable to models that contain these necessary attributes. + forward_context.layer_idx = None + if model_instance is not None and \ + hasattr(model_instance, "model") and \ + hasattr(model_instance.model, "start_layer"): + forward_context.layer_idx = model_instance.model.start_layer + + # set for mlp weight prefetch + prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \ + envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \ + forward_context.layer_idx is not None and \ + num_tokens is not None and num_tokens < 500 + if prefetch_mlp_enabled: + forward_context.prefetch_stream = prefetch_stream + forward_context.model_instance = model_instance + forward_context.prefetch_mlp_gate_up_proj = False + forward_context.prefetch_mlp_down_proj = False + forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled + + # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant. + # It will be improved later by implementing operator fusion through the FX graph. + # + # set for addrmsnorm+quant fusion. + # this optim now just support dense models due to the specific operators used. + # Once the necessary conditions are met, support for MOE models will also be added. 
+ from vllm_ascend.quantization.quant_config import AscendQuantConfig + addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \ + vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \ + forward_context.layer_idx is not None + if addrmsnorm_quant_fusion_enabled: + forward_context.model_instance = model_instance + forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers + forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense" + forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled + if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens @@ -120,7 +170,6 @@ def set_ascend_forward_context( if num_tokens is not None: if num_actual_tokens is None: num_actual_tokens = num_tokens - tp_world_size = get_tensor_model_parallel_world_size() # NOTE: token num which need to pad to when mc2 forward_context.padded_num_tokens = math.ceil( max_tokens_across_dp / tp_world_size) * tp_world_size diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index a0e6334..225d4b9 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -39,11 +39,22 @@ class AttentionMaskBuilder: self, max_seq_len: int, dtype: torch.dtype, + device: torch.device = None, ): + # NOTE: The device argument specifies the target NPU + # to be used for the newly added FIA operator. + # Only pass this parameter when using the new FIA operator. 
+ attn_mask = _generate_attn_mask(max_seq_len, dtype) self._seq_len_cached = attn_mask.shape[0] self.attn_mask_cache = attn_mask + self.device = device + if torch.version.cann.startswith("8.3"): + assigned_mask_dim = 2048 + self.chunked_prefill_attn_mask = torch.triu( + torch.ones(assigned_mask_dim, assigned_mask_dim), + diagonal=1).to(torch.int8).to(device) @staticmethod def get_mask_scale_factor(dtype: torch.dtype = torch.float16): @@ -62,28 +73,32 @@ class AttentionMaskBuilder: device: torch.device): self._update_attn_cache(max_seq_len, dtype) return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( - ).to(device) + ).to(device, non_blocking=True) def get_splitfuse_attn_mask( self, - seq_lens: torch.Tensor, - position: torch.Tensor, - dtype: torch.dtype, - device: torch.device, + seq_lens: torch.Tensor = None, + position: torch.Tensor = None, + dtype: torch.dtype = None, + device: torch.device = None, ) -> torch.Tensor: - if dtype not in [torch.float16, torch.bfloat16]: - raise ValueError( - "splitfuse_attn_mask now only supports bf16 and fp16") - max_seq_len = max(seq_lens, default=0) - self._update_attn_cache(max_seq_len, dtype) - # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation - # is not the same. Fix this in the future when kernel is ready. 
- mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype) - attn_mask = torch.index_select(self.attn_mask_cache, - dim=0, - index=position)[:, :max_seq_len] - attn_mask *= mask_scale_factor - return attn_mask.contiguous().to(device, non_blocking=True) + if torch.version.cann.startswith("8.3"): + return self.chunked_prefill_attn_mask + else: + if dtype not in [torch.float16, torch.bfloat16]: + raise ValueError( + "splitfuse_attn_mask now only supports bf16 and fp16") + max_seq_len = max(seq_lens, default=0) + self._update_attn_cache(max_seq_len, dtype) + # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation + # is not the same. Fix this in the future when kernel is ready. + mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor( + dtype) + attn_mask = torch.index_select(self.attn_mask_cache, + dim=0, + index=position)[:, :max_seq_len] + attn_mask *= mask_scale_factor + return attn_mask.contiguous().to(device, non_blocking=True) def _update_attn_cache(self, seqlen: int, dtype: torch.dtype): if seqlen > self._seq_len_cached: diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 5460b94..d289bb4 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -17,24 +17,27 @@ from dataclasses import dataclass from enum import Enum -from typing import List, Optional, Tuple, Type +from typing import ClassVar, List, Optional, Tuple, Type import torch import torch.nn as nn import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) -from vllm.attention.backends.utils import CommonAttentionState from vllm.config import VllmConfig from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils import cdiv, direct_register_custom_op +from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.core.sched.output import SchedulerOutput +from 
vllm.v1.kv_cache_interface import AttentionSpec -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + maybe_save_kv_layer_to_connector, + wait_for_kv_layer_from_connector) +from vllm_ascend.compilation.acl_graph import get_graph_params from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d, nd_to_nz_spec) -from vllm_ascend.worker.npu_input_batch import InputBatch class AscendAttentionBackend(AttentionBackend): @@ -52,10 +55,6 @@ class AscendAttentionBackend(AttentionBackend): def get_metadata_cls() -> Type["AscendMetadata"]: return AscendMetadata - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - @staticmethod def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: return AscendAttentionMetadataBuilder @@ -111,6 +110,10 @@ class AscendAttentionBackend(AttentionBackend): key_caches[dst_indices] = key_caches[src_indices] value_caches[dst_indices] = value_caches[src_indices] + @staticmethod + def get_supported_block_size() -> list[int]: + return [64] + class AscendAttentionState(Enum): PrefillNoCache = 0 @@ -155,48 +158,50 @@ class AscendMetadata: # *************************** Other Properties *************************** # enable_dbo_across_dp: bool = False - is_only_prefill: bool = False class AscendAttentionMetadataBuilder: + # Does this backend/builder support ACL Graphs for attention (default: no). + aclgraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + # Does this backend/builder reorder the batch? + # If not, set this to None. Otherwise set it to the query + # length that will be pulled into the front of the batch. 
+ reorder_batch_threshold: ClassVar[int] = 1 def __init__( self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], vllm_config: VllmConfig, device: torch.device, ): self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.device = device - self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, - vllm_config.cache_config.block_size) + self.max_num_blocks_per_req = cdiv( + self.model_config.max_model_len, + AscendAttentionBackend.get_supported_block_size()[0]) - def reorder_batch(self, input_batch: "InputBatch", + def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool: return False def build( self, + common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, + model: Optional[nn.Module] = None, ): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] - block_table = common_attn_metadata.block_table_tensor - block_table[:num_reqs, :self.max_num_blocks_per_req] = ( - block_table[:num_reqs]) - query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - slot_mapping = common_attn_metadata.slot_mapping_cpu[: - num_actual_tokens].to( - self.device, - non_blocking= - True) + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] attn_mask = common_attn_metadata.attn_mask attn_state = common_attn_metadata.attn_state query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: @@ -225,8 +230,25 @@ class AscendAttentionMetadataBuilder: slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, - is_only_prefill=common_attn_metadata.is_only_prefill) + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp) + return attn_metadata + + def build_for_graph_capture( + self, + 
common_attn_metadata: AscendCommonAttentionMetadata, + attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, + ): + if attn_state == AscendAttentionState.DecodeOnly: + attn_metadata = self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + else: + raise NotImplementedError( + "Currently we only support building dummy metadata for DecodeOnly state" + ) + + attn_metadata.attn_state = attn_state return attn_metadata @@ -265,20 +287,6 @@ class AscendAttentionBackendImpl(AttentionImpl): self.key_cache = None self.value_cache = None - def _repeat_kv(self, hidden_states: torch.Tensor, - n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, None, :, :].expand( - num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(num_key_value_heads * n_rep, slen, - head_dim) - def _forward_prefill_no_cache( self, query: torch.Tensor, @@ -304,34 +312,15 @@ class AscendAttentionBackendImpl(AttentionImpl): mask = torch_npu.npu_format_cast(mask.contiguous(), ACL_FORMAT_FRACTAL_NZ) - if self.sliding_window is not None and \ - attn_metadata.attn_mask.shape[0] > self.sliding_window: - - key = self._repeat_kv(key, self.num_heads // self.num_kv_heads) - value = self._repeat_kv(value, self.num_heads // self.num_kv_heads) - - output, _ = torch_npu.npu_fused_infer_attention_score( - query, - key, - value, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout="TND", - pre_tokens=self.sliding_window, - scale=self.scale, - actual_seq_lengths=attn_metadata.seq_lens, - actual_seq_lengths_kv=attn_metadata.seq_lens) - output = output.view(num_tokens, self.num_heads, self.head_size) - else: - 
torch_npu._npu_flash_attention(query=query, - key=key, - value=value, - mask=mask, - seq_len=attn_metadata.seq_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - out=output) + torch_npu._npu_flash_attention(query=query, + key=key, + value=value, + mask=mask, + seq_len=attn_metadata.seq_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + out=output) assert output is not None return output[:num_tokens, :, :] @@ -372,7 +361,8 @@ class AscendAttentionBackendImpl(AttentionImpl): # seq_lens_tensor needs to be transferred to the device for 310P. attn_metadata.seq_lens = \ attn_metadata.seq_lens.to(device=query.device) - if self.sliding_window is not None: + if self.sliding_window is not None and attn_metadata.seq_lens.shape[ + 0] == query.size(0): batch_size = attn_metadata.seq_lens.shape[0] block_size = 128 query = query.view(batch_size, 1, self.num_heads * self.head_size) @@ -399,16 +389,53 @@ class AscendAttentionBackendImpl(AttentionImpl): output = output.view(batch_size, self.num_heads, self.head_size) else: - torch_npu._npu_paged_attention( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.block_tables, - context_lens=attn_metadata.seq_lens, - out=output) + graph_params = get_graph_params() + forward_context: ForwardContext = get_forward_context() + num_tokens = query.shape[0] + if forward_context.capturing: + stream = torch_npu.npu.current_stream() + + event = torch.npu.ExternalEvent() + event.wait(stream) + event.reset(stream) + graph_params.events[num_tokens].append(event) + + graph_params.attn_params[num_tokens].append(( + query, + self.key_cache, + self.value_cache, + self.num_kv_heads, + self.num_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens, + output, + )) + + torch.npu.graph_task_group_begin(stream) + 
torch_npu._npu_paged_attention( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.block_tables, + context_lens=attn_metadata.seq_lens, + out=output) + handle = torch.npu.graph_task_group_end(stream) + graph_params.handles[num_tokens].append(handle) + else: + torch_npu._npu_paged_attention( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.block_tables, + context_lens=attn_metadata.seq_lens, + out=output) return output def _forward_v1_style( @@ -449,18 +476,43 @@ class AscendAttentionBackendImpl(AttentionImpl): attn_metadata.seq_lens = \ attn_metadata.seq_lens.to(device=query.device) - torch_npu._npu_paged_attention_splitfuse( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - mask=attn_metadata.attn_mask, - block_table=attn_metadata.block_tables, - seq_len=attn_metadata.query_lens, - context_lens=attn_metadata.seq_lens, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - out=output) + if torch.version.cann.startswith("8.3"): + # TODO:The npu_fused_infer_attention_score op is planned to + # be utilized in a wider range in upcoming versions. 
+ num_block, block_size, _, _ = self.key_cache.shape # type: ignore + key = self.key_cache.view( # type: ignore + num_block, block_size, -1) + value = self.value_cache.view( # type: ignore + num_block, block_size, -1) + + output, _ = torch_npu.npu_fused_infer_attention_score( + query=query, + key=key, + value=value, + atten_mask=attn_metadata.attn_mask, + block_table=attn_metadata.block_tables, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=attn_metadata.query_start_loc[1:], + actual_seq_lengths_kv=attn_metadata.seq_lens, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale=self.scale, + sparse_mode=3, + ) + else: + torch_npu._npu_paged_attention_splitfuse( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + mask=attn_metadata.attn_mask, + block_table=attn_metadata.block_tables, + seq_len=attn_metadata.query_lens, + context_lens=attn_metadata.seq_lens, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output) return output def forward( @@ -554,12 +606,18 @@ class AscendAttentionBackendImpl(AttentionImpl): output) # Normal V1 situation. else: + if torch.version.cann.startswith("8.3"): + # npu_fused_infer_attention_score does not support cases + # where query.shape[0] != attn_metadata.query_start_loc[-1]. + # Thus we need unpad it here. 
+ num_tokens = attn_metadata.query_start_loc[-1] + query = query[:num_tokens] output = self._forward_v1_style(query, attn_metadata, output) # to make in-place change to the output tensor if hasattr(layer, 'quant_method') and use_kv_cache_int8: output = output.view(num_tokens, self.num_heads, self.head_size) - ori_output[:, :, :] = output[:num_tokens, :, :] + ori_output[:num_tokens, :, :] = output[:num_tokens, :, :] return output.view(num_tokens, self.hidden_size) @@ -570,8 +628,11 @@ def unified_ascend_attention_with_output( output: torch.Tensor, layer_name: str, ) -> None: + wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, @@ -582,6 +643,7 @@ def unified_ascend_attention_with_output( attn_metadata, output, trace_flag=False) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) return diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index a386f63..73cbae6 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar +from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type, + TypeVar) import torch import torch_npu @@ -12,15 +13,17 @@ from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.utils import cdiv, round_down +from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils 
import (AscendCommonAttentionMetadata, - split_decodes_and_prefills) + maybe_save_kv_layer_to_connector, + split_decodes_and_prefills, + wait_for_kv_layer_from_connector) from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn -from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla from vllm_ascend.utils import npu_prefetch from vllm_ascend.worker.npu_input_batch import InputBatch @@ -164,6 +167,9 @@ M = TypeVar("M", bound=AscendMLAMetadata) class AscendMLAMetadataBuilder: + # Does this backend/builder support ACL Graphs for attention (default: no). + aclgraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.NEVER """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -171,6 +177,8 @@ class AscendMLAMetadataBuilder: # _attn_mask_builder = None def __init__(self, + kv_cache_spec, + layer_names, vllm_config: VllmConfig, device: torch.device, metadata_cls: Optional[AscendMLAMetadata] = None): @@ -185,7 +193,16 @@ class AscendMLAMetadataBuilder: self.block_size - 1) // self.block_size self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + self.speculative_config = vllm_config.speculative_config self.decode_threshold = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + self.decode_threshold += spec_token_num + assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + npu_fused_infer_attention_score TND layout's limit of 16, \ + got {self.decode_threshold}" + + self.reorder_batch_threshold = self.decode_threshold if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -265,6 +282,7 @@ class AscendMLAMetadataBuilder: def build( self, + common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, model: nn.Module, ) -> 
AscendMLAMetadata: @@ -272,7 +290,6 @@ class AscendMLAMetadataBuilder: num_actual_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) assert num_decodes + num_prefills == num_reqs @@ -284,11 +301,7 @@ class AscendMLAMetadataBuilder: device = self.device block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - slot_mapping = common_attn_metadata.slot_mapping_cpu[: - num_actual_tokens].to( - device, - non_blocking= - True) + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] input_positions = common_attn_metadata.positions[: num_actual_tokens].long( ) @@ -376,11 +389,12 @@ class AscendMLAMetadataBuilder: decode_metadata = None if num_decodes > 0: + # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist() max_seq_lens = seq_lens[:num_decodes].max().item() - seq_lens = seq_lens[:num_decode_tokens] + seq_lens = seq_lens[:num_decodes] input_positions = input_positions[:num_decode_tokens] - block_table = block_table[:num_decode_tokens, ...] + block_table = block_table[:num_decodes, ...] 
seq_lens_list = seq_lens.tolist() cos = self.cos_cache[input_positions].unsqueeze( # type: ignore @@ -481,17 +495,12 @@ class AscendMLAImpl(MLAAttentionImpl): self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_prefetch = ascend_config.enable_prefetch self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz - self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla vllm_config = get_current_vllm_config() self.ring_mla_mask_size = 512 self.prefill_mask = None - # Adapt torch air graph mode with spec decoding. - speculative_config = vllm_config.speculative_config - if speculative_config is not None: - self.spec_token_num = speculative_config.num_speculative_tokens - assert self.spec_token_num > 0 + self.speculative_config = vllm_config.speculative_config def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) @@ -663,84 +672,47 @@ class AscendMLAImpl(MLAAttentionImpl): self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device) - if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - query = torch.cat((q_nope, q_pe), dim=-1) - key = torch.cat((k_nope, k_pe), dim=-1) - torch_npu._npu_flash_attention( - query=query, - key=key, - value=value, - mask=attn_metadata.attn_mask, - seq_len=attn_metadata.prefill.context_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_heads, - out=attn_output) - elif self.chunked_prefill_for_mla: - attn_lse = torch.empty(self.num_heads, - num_tokens, - dtype=torch.float32, - device=q_nope.device) - if self.prefill_mask is None: - self.prefill_mask = torch.triu( - torch.ones(self.ring_mla_mask_size, - self.ring_mla_mask_size, - device=q_nope.device, - dtype=q_nope.dtype), 1) - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=value, - mask=self.prefill_mask, - seqlen=torch.tensor(attn_metadata.prefill.query_lens, - dtype=torch.int32), - head_num=self.num_heads, - 
kv_head_num=self.num_heads, - pre_out=None, - prev_lse=None, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="mask_type_triu", - input_layout="type_bsnd", - calc_type="calc_type_first_ring", - output=attn_output, - softmax_lse=attn_lse) - attn_output, attn_lse = self._compute_prefill_context( \ - q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) - else: - query = torch.cat((q_nope, q_pe), dim=-1) - attn_output_torch = torch.empty(num_tokens, - self.num_heads * self.v_head_dim, - dtype=query.dtype, - device=query.device) - # current requests is chunked in prefill, disable flash attention with chunked prefill - vanilla_chunked_prefill_mla( - output=attn_output_torch, - query=query, - kv_cache=kv_c_and_k_pe_cache, - block_tables=attn_metadata.prefill.block_table, - query_lens=attn_metadata.prefill.query_lens, - context_lens=attn_metadata.prefill.context_lens, - kv_b_proj=self.kv_b_proj, - max_query_len=attn_metadata.prefill.max_query_len, - max_context_len=attn_metadata.prefill.max_seq_lens, - nope_dim=self.qk_nope_head_dim, - rope_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - scale=self.scale, - alibi_slopes=None, - causal=True) + attn_lse = torch.empty(self.num_heads, + num_tokens, + dtype=torch.float32, + device=q_nope.device) + if self.prefill_mask is None: + if q_nope.dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + prefill_mask = torch.triu( + torch.ones(self.ring_mla_mask_size, + self.ring_mla_mask_size, + device=q_nope.device, + dtype=q_nope.dtype), 1) + self.prefill_mask = torch.where(prefill_mask == 1, mask_value, + 0).to(q_nope.dtype) + torch_npu.atb.npu_ring_mla(q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + mask=self.prefill_mask, + seqlen=torch.tensor( + attn_metadata.prefill.query_lens, + dtype=torch.int32), + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + 
prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse) + attn_output, attn_lse = self._compute_prefill_context( \ + q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) - if attn_metadata.attn_state in [ - AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding, - AscendAttentionState.PrefillCacheHit - ] and not self.chunked_prefill_for_mla: - attn_output = attn_output_torch return attn_output def exec_kv_decode( @@ -785,7 +757,7 @@ class AscendMLAImpl(MLAAttentionImpl): # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view( B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA" + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, @@ -840,8 +812,11 @@ class AscendMLAImpl(MLAAttentionImpl): self.qk_rope_head_dim) input_layout = "BNSD" - if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: - assert num_tokens % self.spec_token_num == 0 + if attn_metadata.attn_state in [ + AscendAttentionState.SpecDecoding, + AscendAttentionState.ChunkedPrefill + ] and self.speculative_config is not None: + # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill input_layout = "TND" # [bs * q_seq_len, num_heads_per_rank, dim] q_nope = q_nope.view(num_tokens, self.num_heads, -1) @@ -887,8 +862,8 @@ class AscendMLAImpl(MLAAttentionImpl): current_ms_metadata.before_comm_event.wait() return self._v_up_proj(attn_output) - def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata, - need_gather_q_kv): + def _mla_preprocess(self, layer_name, hidden_states, kv_cache, + 
attn_metadata, need_gather_q_kv): # MLA Preprocess: # 1. Perform q_a_proj and q_a_layernorm to obtain q_c # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split @@ -917,6 +892,8 @@ class AscendMLAImpl(MLAAttentionImpl): kv_no_split = get_tp_group().all_gather(kv_no_split, 0) decode_preprocess_res = None prefill_preprocess_res = None + if has_prefill: + wait_for_kv_layer_from_connector(layer_name) # Preprocess for decode tokens if has_decode: decode_q_c = q_c[:num_decode_tokens] @@ -963,6 +940,7 @@ class AscendMLAImpl(MLAAttentionImpl): def forward( self, + layer_name, hidden_states: torch.Tensor, # query in unified attn kv_cache: Tuple[torch.Tensor], attn_metadata: M, @@ -989,7 +967,8 @@ class AscendMLAImpl(MLAAttentionImpl): # MLA Preprocess decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess( - hidden_states, kv_cache, attn_metadata, need_gather_q_kv) + layer_name, hidden_states, kv_cache, attn_metadata, + need_gather_q_kv) if decode_preprocess_res is not None: # MLA Preprocess for decoding @@ -1047,4 +1026,8 @@ class AscendMLAImpl(MLAAttentionImpl): is_force_scatter=self.enable_shared_expert_dp)[0] current_ms_metadata.after_comm_event.record() del o_proj_input + + has_prefill = attn_metadata.num_prefills > 0 + if has_prefill: + maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) return output_padded diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py new file mode 100644 index 0000000..55282c8 --- /dev/null +++ b/vllm_ascend/attention/sfa_v1.py @@ -0,0 +1,986 @@ +from dataclasses import dataclass +from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type, + TypeVar) + +import torch +import torch_npu +from torch import nn +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + MLAAttentionImpl) +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from 
vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.utils import cdiv, round_down +from vllm.v1.attention.backends.utils import AttentionCGSupport + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.mla_v1 import AscendMLAMetadata +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + split_decodes_and_prefills) +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_ascend.worker.npu_input_batch import InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + +class AscendSFABackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND_SFA" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AscendSFAMetadata + + @staticmethod + def get_builder_cls(): + return AscendSFAMetadataBuilder + + @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_impl_cls() -> Type["AscendSFAImpl"]: + return AscendSFAImpl + + +@dataclass +class AscendSFAPrefillMetadata: + """ Prefill Specific Metadata for Ascend""" + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + + attn_mask: torch.Tensor + query_lens: list[int] + seq_lens: list[int] + + context_lens: torch.Tensor + input_positions: torch.Tensor + query_start_loc: torch.Tensor + block_table: torch.Tensor + max_query_len: int + max_seq_lens: int + sin: 
torch.Tensor + cos: torch.Tensor + chunked_context: Optional[ChunkedContextMetadata] = None + + +@dataclass +class AscendSFADecodeMetadata: + # Input positions for rotary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + actual_seq_lengths_q: torch.Tensor + sin: torch.Tensor + cos: torch.Tensor + attn_mask: Optional[torch.Tensor] = None + + +@dataclass +class AscendSFAMetadata: + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + slot_mapping: torch.Tensor + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + block_tables: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding.
+ + query_lens: Optional[list[int]] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + attn_mask: torch.Tensor = None + # chunked prefill by default if no attn_states passed + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + decode: Optional[AscendSFADecodeMetadata] = None + prefill: Optional[AscendSFAPrefillMetadata] = None + enable_dbo_across_dp: bool = False + + def __post_init__(self): + pass + # supported_head_sizes = AscendMLABackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f"received {self.head_dim}.") + + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendSFAMetadata"]: + """Split metadata for multi-stream with AscendSFAMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMLAMetadata, + ) + + +M = TypeVar("M", bound=AscendSFAMetadata) + + +class AscendSFAMetadataBuilder: + # Does this backend/builder support ACL Graphs for attention (default: no). 
+ aclgraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.NEVER + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + kv_cache_spec, + layer_names, + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[AscendSFAMetadata] = None): + self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \ + if metadata_cls is not None else AscendSFAMetadata # type: ignore + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + + self.speculative_config = vllm_config.speculative_config + self.decode_threshold = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + self.decode_threshold += spec_token_num + assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + npu_fused_infer_attention_score TND layout's limit of 16, \ + got {self.decode_threshold}" + + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Make sure there is enough for 8 full length requests or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + 
scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the back using the least + # amount of swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if num_tokens <= self.decode_threshold: + decodes.append(i) + else: + prefills.append(i) + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new requests are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reordered batch with + # prefills from the front of the batch. 
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + return modified_batch + + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendSFAMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. 
+ device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping[: + num_actual_tokens].to( + device, + non_blocking=True) + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + + if self.cos_cache is None: + self.cos_cache = model.model.layers[ + 0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + 0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) + + prefill_metadata = None + chunked_context_metadata = None + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens + max_query_len = query_lens[reqs_start:].max().item() + max_seq_lens = seq_lens[reqs_start:].max().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + max_context_chunk = round_down(max_context_chunk, + self.block_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk + chunk_ends = 
torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata = \ + AscendSFAPrefillMetadata.ChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + actual_query_lens = torch.tensor(query_lens[reqs_start:], + dtype=torch.int32).npu() + query_lens_prefill_sfa = torch.cumsum(actual_query_lens, + dim=0).to(torch.int32) + seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu() + prefill_metadata = AscendSFAPrefillMetadata( + attn_mask=common_attn_metadata.attn_mask, + query_lens=query_lens_prefill_sfa, + seq_lens=seq_lens_prefill_sfa, + context_lens=seq_lens[reqs_start:], + input_positions=prefill_input_positions, + block_table=block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_seq_lens=max_seq_lens, + query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, + ) + + decode_metadata = None + if num_decodes > 0: + # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario + actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to( + torch.int32).npu() + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decodes].to(torch.int32).npu() + input_positions 
= input_positions[:num_decode_tokens] + block_table = block_table[:num_decodes, ...] + seq_lens_list = seq_lens.tolist() + + cos = self.cos_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + + decode_metadata = AscendSFADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin, + cos=cos) + + return self.metadata_cls( # type: ignore + num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, + prefill=prefill_metadata, + decode=decode_metadata, + query_start_loc=query_start_loc, + block_tables=block_table, + seq_lens=seq_lens, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + ) + + +class PrefillSFAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + k_nope: Optional[torch.Tensor] = None + k_pe: Optional[torch.Tensor] = None + topk_indices: Optional[torch.Tensor] = None + query_states: Optional[torch.Tensor] = None + key_states: Optional[torch.Tensor] = None + + +class DecodeSFAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + # nope_cache: Optional[torch.Tensor] = None + # rope_cache: Optional[torch.Tensor] = None + topk_indices: Optional[torch.Tensor] = None + query_states: Optional[torch.Tensor] = None + key_states: Optional[torch.Tensor] = None + bsz: Optional[int] = None + + +class AscendSFAImpl(MLAAttentionImpl): + """ + NOTE: Please read the comment at the 
top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + # MLA Args + self.q_lora_rank = kwargs['q_lora_rank'] + self.kv_lora_rank = kwargs['kv_lora_rank'] + self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] + self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] + self.qk_head_dim = kwargs['qk_head_dim'] + self.v_head_dim = kwargs['v_head_dim'] + self.rotary_emb = kwargs['rotary_emb'] + self.q_proj = kwargs['q_proj'] + self.kv_b_proj = kwargs['kv_b_proj'] + self.o_proj = kwargs['o_proj'] + self.indexer = kwargs['indexer'] + self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) + self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) + self.q_a_proj = kwargs.get('q_a_proj', None) + self.q_a_layernorm = kwargs.get('q_a_layernorm', None) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = self.num_heads // self.tp_size + if self.q_a_proj is not None: + self.q_b_proj = self.q_proj + else: + self.q_b_proj = None + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.enable_prefetch = ascend_config.enable_prefetch + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + + vllm_config = get_current_vllm_config() + self.ring_mla_mask_size = 512 + self.prefill_mask = None + + # indexer param + self.dim = self.indexer.dim + self.n_heads: int = self.indexer.n_heads # 64 + self.head_dim: int = 
self.indexer.head_dim # 128 + self.index_topk: int = self.indexer.index_topk # 2048 + self.wq_b = self.indexer.wq_b + self.wk = self.indexer.wk + self.weights_proj = self.indexer.weights_proj + self.k_norm = self.indexer.k_norm + self.softmax_scale = self.indexer.softmax_scale + + # Adapt torch air graph mode with spec decoding. + speculative_config = vllm_config.speculative_config + if speculative_config is not None: + self.spec_token_num = speculative_config.num_speculative_tokens + assert self.spec_token_num > 0 + + self.cp_size = 1 + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + 
self.qk_nope_head_dim + self.v_head_dim, + ) + + self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous() + + # Waiting for BMM NZ support + # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) + # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + + def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata, + need_gather_q_kv): + # SFA Preprocess: + # 1. Perform q_a_proj and q_a_layernorm to obtain q_c + # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split + # 3. If need_gather_q_kv, perform all_gather. + # 4. Preprocess decode tokens, write kv cache and get: + # decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope + # 5. Preprocess prefill tokens, write kv cache and get: + # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + + num_decode_tokens = attn_metadata.num_decode_tokens + num_actual_tokens = attn_metadata.num_actual_tokens + if need_gather_q_kv: + # q_c = get_tp_group().all_gather(q_c, 0) + # kv_no_split = get_tp_group().all_gather(kv_no_split, 0) + hidden_states = get_tp_group().all_gather(hidden_states, 0) + # hidden_states_decode = hidden_states[:num_decode_tokens] + # if self.q_a_proj is not None: + # npu_prefetch(self.q_a_proj.weight, + # hidden_states, + # enabled=self.enable_prefetch) + # ckq = self.q_a_proj(hidden_states) # q down + # q_c = self.q_a_layernorm(ckq) # q down layernorm + # else: + # q_c = hidden_states + + # kv_no_split = self.kv_a_proj_with_mqa(hidden_states) # c_kv + # Process for shared_expert_dp + + decode_preprocess_res = None + prefill_preprocess_res = None + # Preprocess for decode tokens + 
if has_decode: + q_len = 1 + hidden_states_decode = hidden_states[:num_decode_tokens] + decode_kq = self.q_a_proj(hidden_states_decode) # q down + decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm + decode_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_decode) # c_kv + + # decode_q_c = q_c[:num_decode_tokens] + decode_slot_mapping = attn_metadata.slot_mapping[: + num_decode_tokens] + # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] + + decode_q = self.q_b_proj(decode_q_c) + bsz, _ = decode_q.shape + decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim) + decode_q_nope, decode_q_pe = torch.split( + decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + decode_q_nope = decode_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + decode_q_nope = (torch.matmul(decode_q_nope, + self.kv_b_proj_w_k).transpose( + 1, + 0).view(bsz, q_len, + self.num_heads, + self.kv_lora_rank)) + + # stream2 kv + key_cache = kv_cache[0] + value_cache = kv_cache[1] + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1) + decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + decode_kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + decode_slot_mapping.to(torch.int64), + value_cache, + key_cache, + c_kv_scale=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode='PA') # adapter NZ + # nz_block_size = 16 + # KVCACHE_NZ_DIM = 16 + # decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size) + # decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM) + + decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos, + sin) # BNSD + + decode_q_nope = 
decode_q_nope.view(bsz, self.num_heads, + self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + + topk_indices = self.indexer_select(hidden_states_decode, + decode_q_c, + attn_metadata=attn_metadata, + kv_cache=kv_cache) + + query_states = (decode_q_nope, decode_q_pe) + key_states = (decode_k_nope, decode_k_rope) + decode_preprocess_res = DecodeSFAPreprocessResult( + q_nope=decode_q_nope, + q_pe=decode_q_pe, + # nope_cache = nope_cache, + # rope_cache = rope_cache, + topk_indices=topk_indices, + query_states=query_states, + key_states=key_states, + bsz=bsz, + ) + + # Preprocess for prefill tokens + if has_prefill: + bsz = 1 + + hidden_states_prefill = hidden_states[ + num_decode_tokens:num_actual_tokens] + prefill_kq = self.q_a_proj(hidden_states_prefill) # q down + prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm + prefill_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_prefill) # c_kv + + # prefill_q_c = q_c[ + # num_decode_tokens:num_actual_tokens] + prefill_slot_mapping = attn_metadata.slot_mapping[ + num_decode_tokens:num_actual_tokens] + # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] + + prefill_slot_mapping = attn_metadata.slot_mapping[ + num_decode_tokens:num_actual_tokens] + # prefill_kv_no_split = kv_no_split[ + # num_decode_tokens:num_actual_tokens] + # prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens] + prefill_qr = prefill_q_c + prefill_q = self.q_b_proj(prefill_qr) + prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim) + prefill_q_nope, prefill_q_pe = torch.split( + prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + prefill_q_nope = prefill_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + prefill_q_nope = (torch.matmul(prefill_q_nope, + self.kv_b_proj_w_k).transpose( + 1, + 0).view(-1, self.num_heads, + self.kv_lora_rank)) + prefill_q_pe = prefill_q_pe.unsqueeze(2) + + # stream2 kv + + nope_cache = kv_cache[0] + 
rope_cache = kv_cache[1] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + cos_q, sin_q = cos, sin + + # cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + # sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + prefill_q_pe = torch_npu.npu_interleave_rope( + prefill_q_pe, cos_q, sin_q) # BNSD + prefill_q_pe = prefill_q_pe.squeeze(2) #BSH + # q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:???? + + prefill_latent_cache = prefill_kv_no_split # (B,S,N,D) + prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + prefill_latent_cache.view( + -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim), + self.kv_a_layernorm.weight, + cos.view(-1, 1, 1, self.qk_rope_head_dim), + sin.view(-1, 1, 1, self.qk_rope_head_dim), + prefill_slot_mapping.to(torch.int64), + rope_cache, + nope_cache, + k_rope_scale=None, + c_kv_scale=None, + k_rope_offset=None, + c_kv_offset=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode="PA") + + topk_indices = self.indexer_select(x=hidden_states_prefill, + qr=prefill_qr, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + query_states = (prefill_q_nope, prefill_q_pe) + key_states = (prefill_k_nope, prefill_k_pe) + prefill_preprocess_res = PrefillSFAPreprocessResult( + q_nope=prefill_q_nope, + q_pe=prefill_q_pe, + topk_indices=topk_indices, + k_nope=prefill_k_nope, + k_pe=prefill_k_pe, + query_states=query_states, + key_states=key_states, + ) + + return decode_preprocess_res, prefill_preprocess_res + + def forward( + self, + hidden_states: torch.Tensor, # query in unified attn + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool = False, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. 
+ return output + num_actual_tokens = attn_metadata.num_actual_tokens + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + num_decode_tokens = attn_metadata.num_decode_tokens + # Inputs and outputs may be padded for CUDA graphs + output = output[:num_actual_tokens, ...] + o_proj_input_shape = (num_actual_tokens, + self.num_heads * self.v_head_dim) + o_proj_input = torch.empty(o_proj_input_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + + # SFA Preprocess + decode_preprocess_res, prefill_preprocess_res = self._sfa_preprocess( + hidden_states, kv_cache, attn_metadata, need_gather_q_kv) + + if decode_preprocess_res is not None: + # bsz, q_len, _, _ = query_states[0].shape + decode_attn_output = self.apply_attention_fusion( + query_states=decode_preprocess_res.query_states, + key_states=decode_preprocess_res.key_states, + attn_metadata=attn_metadata, + topk_indices=decode_preprocess_res.topk_indices) + o_proj_input[:num_decode_tokens] = decode_attn_output + + if prefill_preprocess_res is not None: + prefill_attn_output = self.apply_attention_fusion( + query_states=prefill_preprocess_res.query_states, + key_states=prefill_preprocess_res.key_states, + attn_metadata=attn_metadata, + topk_indices=prefill_preprocess_res.topk_indices) + o_proj_input[num_decode_tokens:] = prefill_attn_output + + output[...] 
+        # Note: Considering the fusion rules of TBMM, attn_output requires a 3-dim shape, and
+        # with appropriate tensor stride for the later 'view' operation if oproj_tp_size > 1.
+ # after reshape: [T(bs*q_len), 1, N//attn_tp_size*D] + # attn_output = attn_output.reshape(-1, self.num_heads * self.v_head_dim) + + return attn_output + + def mla_epilog(self, + attn_output: torch.Tensor = None, + absorb: bool = False): + # TODO: need to check + attn_output = self.o_proj(attn_output.reshape(attn_output.shape[0], + -1), + is_prefill=True, + is_force_scatter=False) + + return attn_output + + def indexer_select( + self, + x: torch.Tensor, + qr: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + ): + if attn_metadata.prefill is not None: + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + actual_seq_lengths_query = attn_metadata.prefill.query_lens + actual_seq_lengths_key = attn_metadata.prefill.seq_lens + block_table = attn_metadata.prefill.block_table + elif attn_metadata.decode is not None: + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q + actual_seq_lengths_key = attn_metadata.decode.seq_lens + block_table = attn_metadata.decode.block_table + + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + # q process in new stream + q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128] + q_pe, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64,64+64] + + q_pe = q_pe.unsqueeze(2) + q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q) + q_pe = q_pe.squeeze(2) + q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + + k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] + k = self.k_norm(k_proj).unsqueeze(1) + k_pe, k_nope = torch.split( + k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64+64] + + k_pe = k_pe.unsqueeze(2) + k_pe = 
torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = k_pe.squeeze(2) + + k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + + if kv_cache is not None: + torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), + attn_metadata.slot_mapping.view( + -1, 1), + k.view(-1, + k.shape[-1])) # b, s, n, d + + weights = self.weights_proj(x) + + topk_indices = torch.ops.custom.npu_lightning_indexer( + query=q, + key=kv_cache[2], + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3) + return topk_indices diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 2ef537f..519cde0 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,7 +1,11 @@ from dataclasses import dataclass -from typing import Any +from typing import Any, List import torch +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) +from vllm.forward_context import ForwardContext, get_forward_context @dataclass @@ -21,6 +25,13 @@ class AscendCommonAttentionMetadata: """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" + seq_lens: torch.Tensor + """same to seq_lens_cpu, for compatibility with some new attn metadata + (such as GDN).""" + + num_computed_tokens_cpu: torch.Tensor + """(batch_size,), the number of computed tokens for each request""" + num_reqs: int """Number of requests""" num_actual_tokens: int @@ -34,7 +45,7 @@ class AscendCommonAttentionMetadata: block_table_tensor: torch.Tensor - slot_mapping_cpu: torch.Tensor + slot_mapping: torch.Tensor actual_seq_lengths_q: list[int] @@ -93,3 +104,34 @@ def split_decodes_and_prefills( num_decode_tokens = query_start_loc[first_prefill].item() num_prefill_tokens = num_tokens - num_decode_tokens return 
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) + + +def wait_for_kv_layer_from_connector(layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + # TODO: assert ascendMetadata + connector.wait_for_layer_load(layer_name) + + +def maybe_save_kv_layer_to_connector( + layer_name: str, + kv_cache_layer: List[torch.Tensor], +): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + # TODO: assert ascendMetadata + connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index f8dfc24..8a41807 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -3,10 +3,12 @@ import dataclasses from contextlib import ExitStack +from dataclasses import dataclass from typing import Any, Callable, Optional from unittest.mock import patch import torch +import torch_npu import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphOptions @@ -15,7 +17,8 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger from vllm.platforms import current_platform -from vllm.utils import weak_ref_tensors + +from ..utils import weak_ref_tensors @dataclasses.dataclass @@ -35,10 +38,10 @@ class ACLGraphWrapper: The workflow of this wrapper in the aclgraph dispatching is as follows: 1. 
At initialization, a runtime mode is assigned to the wrapper (FULL or - PIECEWISE). - 2. At runtime, the wrapper receives a runtime_mode and a + PIECEWISE). + 2. At runtime, the wrapper receives a runtime_mode and a batch_descriptor(key) from the forward context and blindly trust them - for aclgraph dispatching. + for aclgraph dispatching. 3. If runtime_mode is NONE or runtime_mode does not match the mode of the wrapper, just call the runnable directly. 4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, @@ -47,9 +50,9 @@ class ACLGraphWrapper: Note: ACLGraphWrapper does not store persistent buffers or copy any runtime inputs into that buffers for replay. We assume implementing them - is done outside of the wrapper. That is because we do not make any + is done outside of the wrapper. That is because we do not make any assumption on the dynamic shape (batch size) of the runtime inputs, as a - trade-off for staying orthogonal to compilation logic. Nevertheless, + trade-off for staying orthogonal to compilation logic. Nevertheless, tracing and checking the input addresses to be consistent during replay is guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". """ @@ -146,6 +149,7 @@ class ACLGraphWrapper: patch("torch.npu.empty_cache", lambda: None)) # mind-exploding: carefully manage the reference and memory. + forward_context.capturing = True with torch.npu.graph(aclgraph, pool=self.graph_pool): # `output` is managed by pytorch's aclgraph pool output = self.runnable(*args, **kwargs) @@ -183,3 +187,74 @@ class ACLGraphWrapper: logger.info_once("Replaying aclgraph") entry.aclgraph.replay() return entry.output + + +def update_attn_params(update_stream, forward_context, runtime_shape): + graph_params = get_graph_params() + # FIXME: Behold! We are using a temporary hack here to update the args + # for each layer's attention op in the graph. 
+ for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + query, + key_cache, + value_cache, + num_kv_heads, + num_heads, + scale, + block_table, + seq_lens, + output, + ) = param + # block_table = forward_context.attn_metadata[key].block_tables + seq_lens = forward_context.attn_metadata[key].seq_lens + + with torch.npu.stream(update_stream): + torch.npu.graph_task_update_begin(update_stream, handle) + torch_npu._npu_paged_attention(query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + +@dataclass +class GraphParams: + events: dict[int, list[torch.npu.ExternalEvent]] + workspaces: dict[int, torch.Tensor] + handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]] + attn_params: dict[int, list[tuple]] + + +_graph_params: Optional[GraphParams] = None + + +def set_graph_params(aclgraph_capture_sizes: set[int]): + global _graph_params + if _graph_params is not None: + raise ValueError("Graph parameters have already been set!") + _graph_params = GraphParams( + {size: [] + for size in aclgraph_capture_sizes}, + {size: None + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + ) + + +def get_graph_params(): + return _graph_params diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py index 4ee02e7..3736534 100644 --- a/vllm_ascend/core/schedule_config.py +++ b/vllm_ascend/core/schedule_config.py @@ -20,14 +20,19 @@ from typing import Type, Union from vllm.config import SchedulerConfig +MAX_INT = 2147483647 + @dataclass class AscendSchedulerConfig(SchedulerConfig): 
enable_chunked_prefill: bool = False + max_long_partial_prefills: int = MAX_INT + long_prefill_token_threshold: int = MAX_INT policy: str = "fcfs" - num_scheduler_steps: int = 1 scheduler_cls: Union[str, Type[object]] = ( "vllm_ascend.core.scheduler.AscendScheduler") + enable_pd_transfer: bool = False + decode_max_num_seqs: int = 0 @classmethod def initialize_from_config( @@ -41,10 +46,13 @@ class AscendSchedulerConfig(SchedulerConfig): } # Override default values into original SchedulerConfig scheduler_config["enable_chunked_prefill"] = False + scheduler_config["max_long_partial_prefills"] = None + scheduler_config["long_prefill_token_threshold"] = None scheduler_config["policy"] = "fcfs" - scheduler_config["num_scheduler_steps"] = 1 scheduler_config["scheduler_cls"] = ( "vllm_ascend.core.scheduler.AscendScheduler") + scheduler_config["enable_pd_transfer"] = False + scheduler_config["decode_max_num_seqs"] = 0 # Override params in original SchedulerConfig with params in ascend_scheduler_config for k, _ in scheduler_config.items(): if hasattr(ascend_scheduler_config, k): @@ -65,20 +73,36 @@ class AscendSchedulerConfig(SchedulerConfig): "max_num_batched_tokens and makes vLLM reject longer " "sequences. Please increase max_num_batched_tokens or " "decrease max_model_len.") + # concurrent partial prefills. 
Default is inf + if self.max_long_partial_prefills is None: + self.max_long_partial_prefills = MAX_INT + self.long_prefill_token_threshold = MAX_INT + + if self.long_prefill_token_threshold is None or \ + self.long_prefill_token_threshold <= 0: + if self.max_model_len is None: + self.long_prefill_token_threshold = MAX_INT + else: + self.long_prefill_token_threshold = \ + max(1, int(self.max_model_len * 0.04)) + + if self.max_long_partial_prefills < 0: + raise ValueError( + f"max_long_partial_prefills must be non-negative, but got " + f"{self.max_long_partial_prefills}") + if self.long_prefill_token_threshold < 0: + raise ValueError( + f"long_prefill_token_threshold must be non-negative, but got " + f"{self.long_prefill_token_threshold}") + if self.policy != "fcfs": raise NotImplementedError( f"currently AscendScheduler only supports fcfs policy, got {self.policy}" ) - if self.is_multimodal_model: - raise NotImplementedError( - "currently AscendScheduler only supports LLM models.") - if self.num_scheduler_steps > 1: - raise NotImplementedError( - "currently AscendScheduler doesn't support multi-step.") if self.send_delta_data: raise NotImplementedError( "currently AscendScheduler doesn't support send_delta_data.") - if self.delay_factor > 0: + if getattr(self, "scheduler_delay_factor", 0) > 0: raise NotImplementedError( "currently AscendScheduler doesn't support scheduler_delay_factor." 
) diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index f8c7f49..f4c8cc7 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -23,6 +23,7 @@ from vllm.distributed.kv_events import KVEventBatch from vllm.logger import logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.utils import cdiv +from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs @@ -31,13 +32,6 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - from vllm.v1.core.kv_cache_manager import KVCacheBlocks -else: - KVCacheBlocks = None - class AscendScheduler(Scheduler): """This Scheduler extends vllm's original v1 scheduler @@ -58,6 +52,15 @@ class AscendScheduler(Scheduler): self.scheduled_req_ids: set[str] = set() self.running: list[Request] = [] + self.finished_prefill_reqs: deque[Request] = deque() + enable_pd_transfer = getattr(self.scheduler_config, + 'enable_pd_transfer', False) + decode_max_num_seqs = getattr(self.scheduler_config, + 'decode_max_num_seqs', 0) + self.phase = "" if not enable_pd_transfer else "prefill" + self.decode_max_num_running_reqs = max(self.max_num_running_reqs, + decode_max_num_seqs) + def schedule(self) -> SchedulerOutput: if self.scheduler_config.chunked_prefill_enabled: return super().schedule() @@ -66,12 +69,14 @@ class AscendScheduler(Scheduler): scheduled_running_reqs: list[Request] = [] preempted_reqs: list[Request] = [] - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - req_to_new_block_ids: dict[str, list[list[int]]] = {} - else: - req_to_new_blocks: dict[str, 
KVCacheBlocks] = {} + req_to_new_blocks: dict[str, KVCacheBlocks] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens + + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens + # Spec decode-related. scheduled_spec_decode_tokens: dict[str, list[int]] = {} @@ -85,9 +90,33 @@ class AscendScheduler(Scheduler): # and put back at the head of the waiting queue later skipped_waiting_requests: deque[Request] = deque() + if self.phase == "prefill": + remaining_running_reqs = [] + for request in self.running: + # move request has finished prefill to finished_prefill_reqs + if request.num_tokens > request.num_prompt_tokens: + self.finished_prefill_reqs.append(request) + else: + remaining_running_reqs.append(request) + self.running = remaining_running_reqs + # all request prefilled, change phase to decode + if not self.waiting and not self.running: + self.phase = "decode" + # Skip long prompt requests in prefill stage. + # long_prefill_budget is float('inf') if not use. + if self.vllm_config.scheduler_config.long_prefill_token_threshold == 0: + long_prefill_budget = float('inf') + long_prefill_token_threshold = float('inf') + else: + long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills + long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold + # Schedule prefill requests first. while self.waiting and token_budget > 0: - if len(self.running) == self.max_num_running_reqs: + if len(self.running) == (self.decode_max_num_running_reqs + if self.phase == "decode" else + self.max_num_running_reqs): + break request = self.waiting[0] @@ -139,6 +168,9 @@ class AscendScheduler(Scheduler): num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + # P/D: loading remote KV, do not allocate for new work. 
if load_kv_async: assert num_external_computed_tokens > 0 @@ -176,6 +208,17 @@ class AscendScheduler(Scheduler): assert num_new_tokens > 0 blocks = new_computed_blocks.blocks[0] + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0 or len( + encoder_inputs_to_schedule) == 0: + # The request cannot be scheduled. + break + watermark = getattr(self.scheduler_config, "watermark", 0.01) if not self._check_watermark_for_prefill(request, num_new_tokens, blocks, watermark): @@ -183,6 +226,11 @@ class AscendScheduler(Scheduler): skip_cur_request() continue + if num_new_tokens > long_prefill_token_threshold \ + and long_prefill_budget <= 0: + skip_cur_request() + continue + new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_computed_tokens, @@ -227,26 +275,41 @@ class AscendScheduler(Scheduler): if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - req_to_new_block_ids[request.request_id] = ( - self.kv_cache_manager.get_block_ids(request.request_id)) - else: - req_to_new_blocks[ - request.request_id] = self.kv_cache_manager.get_blocks( - request.request_id) + + req_to_new_blocks[ + request.request_id] = self.kv_cache_manager.get_blocks( + request.request_id) # Update request info. num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens + if num_new_tokens > long_prefill_token_threshold: + long_prefill_budget -= 1 request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens # Count the number of prefix cached tokens. if request.num_cached_tokens < 0: request.num_cached_tokens = num_computed_tokens + # Encoder-related. 
+ if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: self.waiting.extendleft(skipped_waiting_requests) + if self.phase == "decode": + while len( + self.running + ) < self.decode_max_num_running_reqs and self.finished_prefill_reqs: + request = self.finished_prefill_reqs.popleft() + self.running.append(request) + # If no prefill requests are scheduled, # Schedule decode requests next. if len(self.scheduled_req_ids) == 0: @@ -267,6 +330,16 @@ class AscendScheduler(Scheduler): num_new_tokens = min( num_new_tokens, self.max_model_len - request.num_computed_tokens) + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, request.num_computed_tokens, num_new_tokens, + encoder_budget) + # Check that adding the request still respects the max_loras # constraint. if self.lora_config and request.lora_request and ( @@ -322,11 +395,7 @@ class AscendScheduler(Scheduler): # Schedule the request. 
scheduled_running_reqs.append(request) self.scheduled_req_ids.add(request.request_id) - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - req_to_new_block_ids[request.request_id] = ( - new_blocks.get_block_ids()) - else: - req_to_new_blocks[request.request_id] = new_blocks + req_to_new_blocks[request.request_id] = new_blocks num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -342,6 +411,15 @@ class AscendScheduler(Scheduler): scheduled_spec_decode_tokens[request.request_id] = ( request.spec_token_ids) + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + # Record scheduled LoRA requests. if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) @@ -350,7 +428,9 @@ class AscendScheduler(Scheduler): total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens assert token_budget >= 0 - assert len(self.running) <= self.max_num_running_reqs + assert len( + self.running + ) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( scheduled_running_reqs) <= len(self.running) @@ -365,67 +445,36 @@ class AscendScheduler(Scheduler): any_request, len(self.running))) # Construct the scheduler output. 
- if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - new_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_block_ids[req.request_id]) - for req in scheduled_new_reqs - ] - cached_reqs_data = self._make_cached_request_data( - scheduled_running_reqs, scheduled_resumed_reqs, - num_scheduled_tokens, scheduled_spec_decode_tokens, - req_to_new_block_ids) - else: - new_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) - for req in scheduled_new_reqs - ] + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) + for req in scheduled_new_reqs + ] - cached_reqs_data = self._make_cached_request_data( - scheduled_running_reqs, scheduled_resumed_reqs, - num_scheduled_tokens, scheduled_spec_decode_tokens, - req_to_new_blocks) + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, scheduled_resumed_reqs, + num_scheduled_tokens, scheduled_spec_decode_tokens, + req_to_new_blocks) scheduled_cached_reqs = cached_reqs_data - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - scheduler_output = SchedulerOutput( - scheduled_new_reqs=new_reqs_data, - scheduled_cached_reqs=scheduled_cached_reqs, - num_scheduled_tokens=num_scheduled_tokens, - total_num_scheduled_tokens=total_num_scheduled_tokens, - scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=num_common_prefix_blocks, - # finished_req_ids is an existing state in the scheduler, - # instead of being newly scheduled in this step. - # It contains the request IDs that are finished in between - # the previous and the current steps. - finished_req_ids=self.finished_req_ids, # type: ignore - free_encoder_input_ids=self.encoder_cache_manager. 
- get_freed_ids(), - structured_output_request_ids={}, - grammar_bitmask=None, - ) - else: - scheduler_output = SchedulerOutput( - scheduled_new_reqs=new_reqs_data, - scheduled_cached_reqs=scheduled_cached_reqs, - num_scheduled_tokens=num_scheduled_tokens, - total_num_scheduled_tokens=total_num_scheduled_tokens, - scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=num_common_prefix_blocks, - # finished_req_ids is an existing state in the scheduler, - # instead of being newly scheduled in this step. - # It contains the request IDs that are finished in between - # the previous and the current steps. - finished_req_ids=self.finished_req_ids, # type: ignore - free_encoder_mm_hashes=self.encoder_cache_manager. - get_freed_mm_hashes(), - structured_output_request_ids={}, - grammar_bitmask=None, - ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=scheduled_cached_reqs, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, # type: ignore + free_encoder_mm_hashes=self.encoder_cache_manager. + get_freed_mm_hashes(), + structured_output_request_ids={}, + grammar_bitmask=None, + ) # NOTE(Kuntai): this function is designed for multiple purposes: # 1. 
Plan the KV cache store diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index 458b814..26ddd8f 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -26,3 +26,8 @@ KVConnectorFactory.register_connector( KVConnectorFactory.register_connector( "MooncakeConnectorV1", "vllm_ascend.distributed.mooncake_connector", "MooncakeConnector") + +KVConnectorFactory.register_connector( + "MooncakeConnectorStoreV1", + "vllm_ascend.distributed.mooncake.mooncake_store_connector_v1", + "MooncakeConnectorV1") diff --git a/vllm_ascend/distributed/cpu_offload_connector.py b/vllm_ascend/distributed/cpu_offload_connector.py new file mode 100644 index 0000000..b27595d --- /dev/null +++ b/vllm_ascend/distributed/cpu_offload_connector.py @@ -0,0 +1,457 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +import queue +import threading +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Sequence + +import torch +from vllm.attention import AttentionType +from vllm.attention.layer import Attention +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import get_pp_group, get_tp_group +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.utils import logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec + +from vllm_ascend.distributed.cpu_offload_manager.metadata import ( + MetadataServer, MetadataServerProc, MLAConfig) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from 
vllm.v1.request import Request + + +@dataclass +class ReqMeta: + gpu_block_ids: list[int] + cpu_block_ids: list[int] + num_scheduled_tokens: int + num_computed_tokens: int + num_gpu_computed_tokens: int + num_cpu_computed_tokens: int + + def update(self, other: "ReqMeta"): + self.gpu_block_ids.extend(other.gpu_block_ids) + self.cpu_block_ids.extend(other.cpu_block_ids) + self.num_scheduled_tokens = other.num_scheduled_tokens + self.num_computed_tokens = other.num_computed_tokens + self.num_gpu_computed_tokens = other.num_gpu_computed_tokens + self.num_cpu_computed_tokens = other.num_cpu_computed_tokens + + +@dataclass +class CPUOffloadingConnectorMetadata(KVConnectorMetadata): + requests: dict[str, ReqMeta] + finished_req_ids: set[str] + + +class CPUOffloadingConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + if not vllm_config.cache_config.enable_prefix_caching: + self.connector_scheduler: Optional[ + CPUOffloadingConnectorScheduler] = None + self.connector_worker: Optional[ + CPUOffloadingConnectorWorker] = None + elif role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = CPUOffloadingConnectorScheduler( + vllm_config) + self.connector_worker = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = CPUOffloadingConnectorWorker(vllm_config) + + # ============================== + # Worker-side methods + # ============================== + + def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + if self.connector_worker is not None: + assert isinstance(connector_metadata, + CPUOffloadingConnectorMetadata) + self.connector_worker.bind_connector_metadata(connector_metadata) + + def clear_connector_metadata(self) -> None: + assert self.connector_worker is not None + self.connector_worker.clear_connector_metadata() + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + if self.connector_worker is not None: + 
self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + if self.connector_worker is not None: + self.connector_worker.start_load_kv() + + def wait_for_layer_load(self, layer_name: str) -> None: + if self.connector_worker is not None: + self.connector_worker.wait_for_layer_load() + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + pass + + def wait_for_save(self): + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + assert self.connector_worker is not None + return self.connector_worker.get_finished(), None + + # Scheduler-side methods + # ============================== + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + if self.connector_scheduler is not None: + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + if self.connector_scheduler is not None: + return self.connector_scheduler.update_state_after_alloc(request) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + if self.connector_scheduler is not None: + return self.connector_scheduler.build_connector_meta( + scheduler_output) + return KVConnectorMetadata() + + def request_finished( + self, request: "Request", + block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]: + if self.connector_scheduler is not None: + self.connector_scheduler.request_finished(request) + return True, None + + +class CPUOffloadingConnectorScheduler: + + def __init__(self, vllm_config: VllmConfig): + logger.info("init CPUOffloadingConnectorScheduler") + self.vllm_config = vllm_config + self.block_size = 
vllm_config.cache_config.block_size + self.use_mla = vllm_config.model_config.use_mla + self.num_gpu_computed_tokens: dict[str, int] = {} + self.num_cpu_computed_tokens: dict[str, int] = {} + self.allocated_req_ids: set[str] = set() + self.finished_req_ids: list[str] = [] + self.zmq_rpc_client = MetadataServer.ZMQRPCClient() + self.zmq_rpc_client.call("post_init") + if vllm_config.kv_transfer_config is not None: + self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config( + "swap_in_threshold", 0) + else: + self.swap_in_threshold = 0 + logger.info(f"swap_in_threshold: {self.swap_in_threshold}") + + def get_num_new_matched_tokens( + self, ori_request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + request = copy.deepcopy(ori_request) + request.get_hash_new_full_blocks = None + num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call( + "get_matched_num_and_touch", request) + self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens + self.num_cpu_computed_tokens[ + request.request_id] = num_cpu_computed_tokens + if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold: + return num_cpu_computed_tokens - num_computed_tokens, load_async + else: + return 0, load_async + + def update_state_after_alloc(self, request: "Request"): + self.allocated_req_ids.add(request.request_id) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + num_tokens = {} + # process scheduled_new_reqs + for req in scheduler_output.scheduled_new_reqs: + req_id = req.req_id + num_tokens[req_id] = ( + req.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + # process scheduled_cached_reqs + cached_reqs = scheduler_output.scheduled_cached_reqs + for idx, req_id in enumerate(cached_reqs.req_ids): + num_tokens[req_id] = ( + cached_reqs.num_computed_tokens[idx] + + scheduler_output.num_scheduled_tokens[req_id]) + + unallocated_req_ids = 
set(self.num_gpu_computed_tokens.keys() - + self.allocated_req_ids - + scheduler_output.num_scheduled_tokens.keys()) + new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", + num_tokens, + unallocated_req_ids) + metadata = CPUOffloadingConnectorMetadata( + requests={}, + finished_req_ids=set(self.finished_req_ids), + ) + for req in scheduler_output.scheduled_new_reqs: + req_id = req.req_id + gpu_block_ids = req.block_ids[0] + metadata.requests[req_id] = ReqMeta( + gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids, + cpu_block_ids=new_cpu_block_ids.get(req_id, []), + num_scheduled_tokens=scheduler_output. + num_scheduled_tokens[req_id], + num_computed_tokens=req.num_computed_tokens, + num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id], + num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id]) + + for idx, req_id in enumerate(cached_reqs.req_ids): + gpu_block_ids = cached_reqs.new_block_ids[idx] + metadata.requests[req_id] = ReqMeta( + gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids, + cpu_block_ids=new_cpu_block_ids.get(req_id, []), + num_scheduled_tokens=scheduler_output. 
+ num_scheduled_tokens[req_id], + num_computed_tokens=cached_reqs.num_computed_tokens[idx], + num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx], + num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx]) + self.num_gpu_computed_tokens.clear() + self.num_cpu_computed_tokens.clear() + self.allocated_req_ids.clear() + self.finished_req_ids.clear() + return metadata + + def request_finished(self, ori_request: "Request"): + request = copy.deepcopy(ori_request) + request.get_hash_new_full_blocks = None + self.finished_req_ids.append(request.request_id) + # inform metadata server to record request, and free it after finish sending + self.zmq_rpc_client.call("record_request_cache_and_free_slots", + request) + + +class CPUOffloadingConnectorWorker: + + def __init__(self, vllm_config: VllmConfig): + logger.info("init CPUOffloadingConnectorWorker") + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.pp_rank = get_pp_group().rank_in_group + self.tp_group = get_tp_group() + self.tp_rank = self.tp_group.rank_in_group + self.tp_world_size = self.tp_group.world_size + self.use_mla = vllm_config.model_config.use_mla + + self.requests: dict[str, ReqMeta] = {} + self.load_stream = torch.npu.Stream() + self.save_stream = torch.npu.Stream() + self.zmq_rpc_client = MetadataServer.ZMQRPCClient() + self.load_block_mapping: list[tuple[int, int]] = [] + self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue() + self.save_output_queue: queue.Queue[str] = queue.Queue() + self.save_thread = threading.Thread(target=self._save_listener) + self.save_thread.start() + self.done_sending_count: defaultdict[str, int] = defaultdict(int) + + # start metadata server to init cpu_kv_cache_manager and handle rpc requests + # all dp shared the same metadata server, only start the process on data_rank 0 + if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0: + config = VllmConfig() + 
config.cache_config = vllm_config.cache_config + config.parallel_config = vllm_config.parallel_config + config.kv_transfer_config = vllm_config.kv_transfer_config + self.init_metadata_server(config) + self._wait_for_metadata_process_start() + + def init_metadata_server(self, vllm_config: VllmConfig): + self.metadata_thread = threading.Thread( + target=MetadataServerProc.run_metadata_server, + args=(vllm_config, ), + ) + self.metadata_thread.daemon = True + self.metadata_thread.start() + + def _wait_for_metadata_process_start(self): + # TODO: wait for metadata server to start, add a rpc to check if ready + while True: + try: + if self.zmq_rpc_client.call("ready"): + break + except Exception as e: + logger.info(f"wait for metadata server to start, error: {e}") + time.sleep(1) + + def bind_connector_metadata( + self, connector_metadata: CPUOffloadingConnectorMetadata) -> None: + for req_id, req in connector_metadata.requests.items(): + if req_id in self.requests: + self.requests[req_id].update(req) + req = self.requests[req_id] + else: + self.requests[req_id] = req + for i in range(req.num_gpu_computed_tokens // self.block_size, + req.num_computed_tokens // self.block_size): + self.load_block_mapping.append( + (req.cpu_block_ids[i], req.gpu_block_ids[i])) + for req_id in connector_metadata.finished_req_ids: + if req_id in self.requests: + self.save_input_queue.put((req_id, self.requests[req_id])) + + def clear_connector_metadata(self) -> None: + self.load_block_mapping.clear() + + def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]): + self.gpu_kv_caches = kv_caches + model_config = self.vllm_config.model_config + mla_config: Optional[MLAConfig] = None + if model_config.use_mla: + mla_config = MLAConfig( + model_config.hf_text_config.kv_lora_rank, + model_config.hf_text_config.qk_rope_head_dim) + self.cpu_kv_caches = list( + self.zmq_rpc_client.call( + "init_cpu_kv_caches", + self.pp_rank, + self.tp_rank, + get_kv_cache_spec(self.vllm_config), + 
mla_config, + ).values()) + + def start_load_kv(self) -> None: + self.current_layer = 0 + self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values()) + self.load_kv_layer(0) + + def wait_for_layer_load(self) -> None: + # TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug. + self.load_stream.synchronize() + self.current_layer += 1 + self.load_kv_layer(self.current_layer) + + def load_kv_layer(self, layer: int): + if layer == len(self.gpu_kv_caches): + return + gpu_kv_caches = next(self.gpu_kv_caches_load_iter) + cpu_kv_caches = self.cpu_kv_caches[layer] + with torch.npu.stream(self.load_stream): + for cpu_block_id, gpu_block_id in self.load_block_mapping: + for gpu_layer_part, cpu_layer_part in zip( + gpu_kv_caches, cpu_kv_caches): + gpu_layer_part[gpu_block_id].copy_( + cpu_layer_part[cpu_block_id], non_blocking=True) + + def get_finished(self) -> set[str]: + done_sending: set[str] = set() + while True: + try: + id = self.save_output_queue.get_nowait() + except queue.Empty: + break + done_sending.add(id) + for id in done_sending: + del self.requests[id] + if self.tp_world_size == 1: + return done_sending + if self.tp_rank == 0: + for req_id in done_sending: + self.done_sending_count[req_id] += 1 + other_ranks_finished_ids: list[str] = [] + for i in range(1, self.tp_world_size): + other_ranks_finished_ids.extend( + self.tp_group.recv_object(src=i)) + for req_id in other_ranks_finished_ids: + self.done_sending_count[req_id] += 1 + all_done_sending: set[str] = set() + for req_id in list(self.done_sending_count.keys()): + if self.done_sending_count[req_id] == self.tp_world_size: + del self.done_sending_count[req_id] + all_done_sending.add(req_id) + # release cpu_kv_cache after request sending finished + # to avoid rpc blocking, use thread to call rpc asynchronously + sending_finished_thread = threading.Thread( + target=self._sending_finished, args=(all_done_sending, )) + sending_finished_thread.daemon = True + 
sending_finished_thread.start() + + return all_done_sending + else: + self.tp_group.send_object(done_sending, dst=0) + return done_sending + + def _sending_finished(self, all_done_sending): + for req_id in all_done_sending: + logger.debug(f"call cache_and_free_slots for req_id: {req_id}") + self.zmq_rpc_client.call("cache_and_free_slots", req_id) + + def _save_listener(self): + save_block_mapping = [] + while True: + req_id, req = self.save_input_queue.get() + for i in range( + req.num_cpu_computed_tokens // self.block_size, + min((req.num_computed_tokens + req.num_scheduled_tokens) // + self.block_size, len(req.cpu_block_ids))): + save_block_mapping.append( + (req.gpu_block_ids[i], req.cpu_block_ids[i])) + with torch.npu.stream(self.save_stream): + # MLA: kv_layer is tuple[tensor, tensor] means (rope, nope). + # non-MLA: kv_layer is list[tensor], typically means [k, v]. + if self.use_mla: + start, step = self.tp_rank, self.tp_world_size + else: + start, step = 0, 1 + for i in range(start, len(save_block_mapping), step): + gpu_block_id, cpu_block_id = save_block_mapping[i] + for cpu_kv_caches, gpu_kv_caches in zip( + self.cpu_kv_caches, self.gpu_kv_caches.values()): + for cpu_layer_part, gpu_layer_part in zip( + cpu_kv_caches, gpu_kv_caches): + cpu_layer_part[cpu_block_id].copy_( + gpu_layer_part[gpu_block_id], + non_blocking=True) + self.save_stream.synchronize() + self.save_output_queue.put(req_id) + save_block_mapping.clear() + + +# Copied from vllm_ascend/worker/model_runner_v1.py. 
+def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: + forward_ctx = vllm_config.compilation_config.static_forward_context + block_size = vllm_config.cache_config.block_size + use_mla = vllm_config.model_config.use_mla + kv_cache_spec: dict[str, KVCacheSpec] = {} + for layer_name, attn_module in forward_ctx.items(): + if isinstance(attn_module, FusedMoE): + continue + assert isinstance(attn_module, Attention) + if attn_module.attn_type == AttentionType.DECODER: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + use_mla=use_mla) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + return kv_cache_spec diff --git a/vllm_ascend/ops/layers/__init__.py b/vllm_ascend/distributed/cpu_offload_manager/__init__.py similarity index 100% rename from vllm_ascend/ops/layers/__init__.py rename to vllm_ascend/distributed/cpu_offload_manager/__init__.py diff --git a/vllm_ascend/distributed/cpu_offload_manager/cpu_kv_cache_manager.py b/vllm_ascend/distributed/cpu_offload_manager/cpu_kv_cache_manager.py new file mode 100644 index 0000000..fd68189 --- /dev/null +++ b/vllm_ascend/distributed/cpu_offload_manager/cpu_kv_cache_manager.py @@ -0,0 +1,202 @@ +import time +from collections import defaultdict +from typing import Optional + +from vllm.utils import logger, sha256 +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, + PrefixCachingMetrics) +from vllm.v1.core.single_type_kv_cache_manager import \ + get_manager_for_kv_cache_spec +from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.v1.request import 
Request + + +class CPUCacheStats: + + def __init__(self, enable_prefix_caching: bool, log_stats: bool = False): + self.enable_prefix_caching = enable_prefix_caching + self.log_stats = log_stats + self.prefix_cache_stats = PrefixCacheStats() if log_stats else None + self.cpu_prefix_cache_metrics = PrefixCachingMetrics() + self.time_sec = int(time.time()) + + def log(self): + current_time_sec = int(time.time()) + # Log the prefix cache hit rate every 10 seconds. + if current_time_sec - self.time_sec >= 10: + self.time_sec = current_time_sec + logger.info("CPU Prefix cache hit rate: %.1f%%", + self.cpu_prefix_cache_metrics.hit_rate * 100) + + def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: + """Get (and reset) the prefix cache stats. + Returns: + The current prefix caching stats, or None if logging is disabled. + """ + if not self.log_stats: + return None + stats = self.prefix_cache_stats + self.prefix_cache_stats = PrefixCacheStats() + return stats + + def update(self, num_tokens, num_computed_tokens): + # Note the function is called by scheduler + if self.log_stats and self.enable_prefix_caching: + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.requests += 1 + self.prefix_cache_stats.queries += num_tokens + self.prefix_cache_stats.hits += num_computed_tokens + + def set_cache_stats(self, num_tokens, num_computed_tokens): + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.hits = num_computed_tokens + self.prefix_cache_stats.queries = num_tokens + self.prefix_cache_stats.requests = 1 + + +class CPUKVCacheManager: + + def __init__( + self, + kv_cache_spec: KVCacheSpec, + num_cpu_blocks: int, + caching_hash_algo: str = "builtin", + use_eagle: bool = False, + enable_kv_cache_events: bool = False, + ) -> None: + self.block_size = kv_cache_spec.block_size + self.num_cpu_blocks = num_cpu_blocks + self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash + self.use_eagle = use_eagle + self.block_pool = 
BlockPool(self.num_cpu_blocks, True, + enable_kv_cache_events) + self.single_type_manager = get_manager_for_kv_cache_spec( + kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + kv_cache_group_id=0, + ) + # Record kv block hashes, avoid redundant computation. + self.req_to_block_hashes: defaultdict[ + str, list[BlockHash]] = defaultdict(list) + # Record blocks touched in get_matched_num_and_touch(). + self.req_to_computed_blocks: defaultdict[ + str, list[KVCacheBlock]] = defaultdict(list) + # Record the request that failed to allocate. + self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool) + self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int) + self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, + log_stats=True) + # Record request that will be free after finish sending + self.req_to_free: defaultdict[str, Request] = defaultdict(Request) + + def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]: + # When the request requires prompt logprobs, we skip prefix caching. + if (request.sampling_params.prompt_logprobs is not None): + return 0, False + request_id = request.request_id + # The block hashes for the request may already be computed + # if the scheduler has tried to schedule the request before. + block_hashes = self.req_to_block_hashes[request_id] + if not block_hashes: + block_hashes = request.block_hashes + self.req_to_block_hashes[request_id] = block_hashes + max_cache_hit_length = request.num_tokens - 1 + computed_blocks = self.single_type_manager.find_longest_cache_hit( + block_hashes=block_hashes, + max_length=max_cache_hit_length, + kv_cache_group_ids=[0], + block_pool=self.block_pool, + kv_cache_spec=self.single_type_manager.kv_cache_spec, + use_eagle=self.use_eagle, + ) + num_computed_tokens = len(computed_blocks[0]) * self.block_size + self.req_to_computed_blocks[request_id] = computed_blocks[0] + # We should touch these blocks in the concurrent scenarios. 
+ self.block_pool.touch(computed_blocks) + + # cup prefix cache status set and log + assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None + self.cpu_cache_stats.set_cache_stats(request.num_tokens, + num_computed_tokens) + self.cpu_cache_stats.cpu_prefix_cache_metrics.observe( + self.cpu_cache_stats.prefix_cache_stats) + self.cpu_cache_stats.log() + + return num_computed_tokens, False + + def _release_ahead_touch(self, request_id: str): + computed_blocks = self.req_to_computed_blocks[request_id] + if computed_blocks: + self.single_type_manager.block_pool.free_blocks( + reversed(computed_blocks)) + self.req_to_computed_blocks.pop(request_id, None) + + def allocate_slots(self, req_to_num_tokens: dict[str, int], + unallocated_req_ids: set[str]) -> dict[str, list[int]]: + for request_id in unallocated_req_ids: + self._free_slots(request_id) + req_to_new_blocks = {} + for request_id, num_tokens in req_to_num_tokens.items(): + if self.req_failed_to_allocate[request_id]: + continue + new_computed_blocks = self.req_to_computed_blocks[request_id] + num_blocks_to_allocate = ( + self.single_type_manager.get_num_blocks_to_allocate( + request_id=request_id, + num_tokens=num_tokens, + new_computed_blocks=new_computed_blocks, + )) + if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + self._release_ahead_touch(request_id) + self.req_failed_to_allocate[request_id] = True + continue + # Append the new computed blocks to the request blocks until now to + # avoid the case where the new blocks cannot be allocated. + self.single_type_manager.save_new_computed_blocks( + request_id, new_computed_blocks) + # Allocate new blocks but do not cache now. + new_blocks = self.single_type_manager.allocate_new_blocks( + request_id, num_tokens) + self.req_to_num_tokens[request_id] = num_tokens + # No need to release ref_cnt because we use officially. 
+ self.req_to_computed_blocks.pop(request_id, None) + req_to_new_blocks[request_id] = [ + block.block_id for block in new_computed_blocks + new_blocks + ] + return req_to_new_blocks + + def record_request_cache_and_free_slots(self, request: Request): + logger.debug( + f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager" + ) + self.req_to_free[request.request_id] = request + + def cache_and_free_slots(self, request_id: str): + logger.debug( + f"Cache and free slots for request {request_id} in cpu_kv_cache_manager" + ) + if request_id not in self.req_to_free: + logger.error( + f"request {request_id} not in req_to_free, maybe bug!") + return + request = self.req_to_free[request_id] + if not self.req_failed_to_allocate[request_id]: + self.single_type_manager.cache_blocks( + request, + self.req_to_num_tokens[request_id], + ) + self._free_slots(request_id) + logger.debug( + f"delete request {request_id} in cpu_kv_cache_manager req_to_free") + del self.req_to_free[request_id] + + def _free_slots(self, request_id: str): + # This function is designed to be reentrant.
+ self._release_ahead_touch(request_id) + self.single_type_manager.free(request_id) + self.req_to_block_hashes.pop(request_id, None) + self.req_to_computed_blocks.pop(request_id, None) + self.req_failed_to_allocate.pop(request_id, None) + self.req_to_num_tokens.pop(request_id, None) diff --git a/vllm_ascend/distributed/cpu_offload_manager/metadata.py b/vllm_ascend/distributed/cpu_offload_manager/metadata.py new file mode 100644 index 0000000..ddfd37c --- /dev/null +++ b/vllm_ascend/distributed/cpu_offload_manager/metadata.py @@ -0,0 +1,269 @@ +import math +import os +import pickle +from dataclasses import dataclass +from multiprocessing.shared_memory import SharedMemory +from typing import Any, Callable, Optional + +import torch +import vllm.envs as envs +import zmq +from vllm.config import KVTransferConfig, VllmConfig +from vllm.utils import get_dtype_size, logger, make_zmq_socket +from vllm.v1.kv_cache_interface import AttentionSpec + +from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \ + CPUKVCacheManager + + +@dataclass +class MLAConfig: + nope_dim: int + rope_dim: int + + +def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig: + if vllm_config.kv_transfer_config is not None: + kv_transfer_config = vllm_config.kv_transfer_config + if kv_transfer_config.kv_connector == "CPUOffloadingConnector": + return kv_transfer_config + elif kv_transfer_config.kv_connector == "MultiConnector": + ktcs = kv_transfer_config.kv_connector_extra_config.get( + "connectors") + for ktc in ktcs: + kv_transfer_config = KVTransferConfig(**ktc) + if kv_transfer_config.kv_connector == "CPUOffloadingConnector": + return kv_transfer_config + return None + + +class MetadataServer: + METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc" + DEFAULT_CPU_SWAP_SPACE_GB = 800 + + class ZMQRPCClient: + + def __init__(self, identity=f"worker-{os.getpid()}"): + logger.info(f"metadata client for worker {identity} started") + self.ctx = 
zmq.Context() # type: ignore + self.socket = make_zmq_socket( + self.ctx, + MetadataServer.METADATA_SERVER_ADDRESS, + zmq.DEALER, # type: ignore + bind=False, + identity=identity.encode(), + linger=0) + + def call(self, func_name: str, *args, **kwargs) -> Any: + request = (func_name, args, kwargs) + self.socket.send(b"", zmq.SNDMORE) # type: ignore + self.socket.send(pickle.dumps(request)) + _ = self.socket.recv() + response = pickle.loads(self.socket.recv()) + result, error = response + if error: + logger.exception(f"call metadata sever error: {error}") + raise error + if func_name == "init_cpu_kv_caches": + (memory_dict, layer_size, layer_dtype, mla_config) = result + # shared_memory_dict is recorded in self to close + self.shared_memory_dict = memory_dict + result = {} + for key, shm in memory_dict.items(): + tensor = torch.frombuffer( + shm.buf, dtype=layer_dtype).reshape(layer_size) + if mla_config is not None: + tensor = tensor.split( + [mla_config.nope_dim, mla_config.rope_dim], dim=-1) + result[key] = tensor + return result + + def __del__(self): + # will be finalized by outer process + self.socket.close() + self.ctx.term() + if hasattr(self, 'shared_memory_dict'): + for shm in self.shared_memory_dict.values(): + shm.close() + + def __init__(self, vllm_config: VllmConfig): + self.world_size = vllm_config.parallel_config.world_size + self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size + kv_transfer_config = get_cpu_offload_connector(vllm_config) + assert kv_transfer_config is not None + available_memory_gb = kv_transfer_config.get_from_extra_config( + "cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB) + self.available_memory = available_memory_gb * 1024 * 1024 * 1024 + logger.info(f"cpu swap space: {self.available_memory} bytes") + self.ctx = zmq.Context() # type: ignore + self.socket = make_zmq_socket( + self.ctx, + MetadataServer.METADATA_SERVER_ADDRESS, + zmq.ROUTER, # type: ignore + bind=True, + linger=0) + 
self.functions: dict[str, Callable] = { + "init_cpu_kv_caches": self.init_cpu_kv_caches, + "post_init": self.post_init, + "ready": self.ready, + } + self.shared_memory = {} # type: ignore + self.num_cpu_blocks = -1 + + @staticmethod + def _safe_create_shared_memory(name: str, size: int) -> SharedMemory: + try: + existing_shm = SharedMemory(name=name, create=False) + existing_shm.close() + existing_shm.unlink() + except FileNotFoundError: + pass + return SharedMemory(name=name, create=True, size=size) + + def ready(self): + return True + + def init_cpu_kv_caches( + self, + pp_rank: int, + tp_rank: int, + kv_cache_specs: dict[str, AttentionSpec], + mla_config: MLAConfig, + ) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype, + MLAConfig]: + logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}") + # follow the assumption that each layer has the same spec + layer = next(iter(kv_cache_specs.values())) + assert all([ + layer.page_size_bytes == any.page_size_bytes + for any in kv_cache_specs.values() + ]) + # mla shares the same kv cache among different tp + if layer.use_mla: + tp_rank = 0 + if (pp_rank, tp_rank) in self.shared_memory: + return self.shared_memory[(pp_rank, tp_rank)] + available_memory = self.available_memory + shared_memory_dict = {} + if layer.use_mla: + available_memory //= self.pipeline_parallel_size + available_memory //= len(kv_cache_specs) + num_blocks = available_memory // layer.page_size_bytes + layer_size = (num_blocks, layer.block_size, layer.num_kv_heads, + layer.head_size) # type: ignore + else: + available_memory //= self.world_size + available_memory //= len(kv_cache_specs) + num_blocks = available_memory // layer.page_size_bytes + layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads, + layer.head_size) # type: ignore + nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype) + for layer_name in kv_cache_specs.keys(): + # only this format can share during ZeroMQ+pickle + shared_memory_dict[ + layer_name] 
= MetadataServer._safe_create_shared_memory( + f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes) + if layer.use_mla: + assert mla_config is not None + assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim + self.shared_memory[(pp_rank, + tp_rank)] = (shared_memory_dict, layer_size, + layer.dtype, mla_config) + else: + self.shared_memory[(pp_rank, + tp_rank)] = (shared_memory_dict, layer_size, + layer.dtype, None) + if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks: + self.num_cpu_blocks = num_blocks + self.layer = layer + return self.shared_memory[(pp_rank, tp_rank)] + + def post_init(self): + # different processors in data parallel may call multiple times + if hasattr(self, 'cpu_block_manager'): + return + # do shared_memory() at least once + logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}") + assert self.num_cpu_blocks >= 0 + self.cpu_block_manager = CPUKVCacheManager(self.layer, + self.num_cpu_blocks) + self.functions.update({ + "get_matched_num_and_touch": + self.cpu_block_manager.get_matched_num_and_touch, + "allocate_slots": + self.cpu_block_manager.allocate_slots, + "record_request_cache_and_free_slots": + self.cpu_block_manager.record_request_cache_and_free_slots, + "cache_and_free_slots": + self.cpu_block_manager.cache_and_free_slots, + }) + + def serve_step(self): + client_id = self.socket.recv() + _ = self.socket.recv() + raw_msg = self.socket.recv() + try: + func_name, args, kwargs = pickle.loads(raw_msg) + except Exception as e: + response = (None, Exception(f"Invalid request: {str(e)}")) + else: + if func_name in self.functions: + try: + result = self.functions[func_name](*args, **kwargs) + response = (result, None) # type: ignore + except Exception as e: + logger.exception(f"metadata execute error: {e}") + response = (None, e) # type: ignore + else: + response = (None, NameError(f"Function {func_name} not found")) + self.socket.send(client_id, zmq.SNDMORE) # type: ignore + self.socket.send(b"", 
zmq.SNDMORE) # type: ignore + self.socket.send(pickle.dumps(response)) + + def shutdown(self): + self.socket.close() + self.ctx.term() + socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace( + "ipc://", "") + if os.path.exists(socket_path): + os.remove(socket_path) + for cached in self.shared_memory.values(): + for shm in cached[0].values(): + shm.close() + shm.unlink() + + +class MetadataServerProc: + + @staticmethod + def run_metadata_server(vllm_config: VllmConfig): + if (not vllm_config.cache_config.enable_prefix_caching + or get_cpu_offload_connector(vllm_config) is None): + return + + shutdown_requested = False + + def _signal_handler(signum, frame): + nonlocal shutdown_requested + if not shutdown_requested: + shutdown_requested = True + raise SystemExit() + + # Either SIGTERM or SIGINT will terminate the worker + # signal.signal(signal.SIGTERM, _signal_handler) + # signal.signal(signal.SIGINT, _signal_handler) + metadata_server: Optional[MetadataServer] = None + try: + metadata_server = MetadataServer(vllm_config) + logger.info("Metadata server started.") + while True: + metadata_server.serve_step() + except SystemExit: + logger.info("Metadata server exiting.") + raise + except Exception as e: + logger.exception(f"Metadata server error: {e}.") + raise e + finally: + if metadata_server is not None: + metadata_server.shutdown() diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index fe6617a..6169328 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -1,4 +1,5 @@ import contextlib +import copy import json import math import os @@ -17,6 +18,7 @@ import torch import zmq from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist, LLMException, LLMRole) +from vllm import envs from vllm.config import KVTransferConfig, VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base 
import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) @@ -184,6 +186,7 @@ class LLMDataDistCMgrConnectorScheduler(): self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + self._reqs_need_send: dict[str, float] = {} def get_num_new_matched_tokens( self, request: "Request", @@ -248,7 +251,12 @@ class LLMDataDistCMgrConnectorScheduler(): meta.add_new_req(request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params) + + meta.reqs_to_send = copy.deepcopy(self._reqs_need_send) + + # Clear the list once workers start the transfers self._reqs_need_recv.clear() + self._reqs_need_send.clear() return meta @@ -275,6 +283,9 @@ class LLMDataDistCMgrConnectorScheduler(): if delay_free_blocks: logger.info("Delaying free of %d blocks for request %s", len(computed_block_ids), request.request_id) + # Prefill request on remote. 
It will be read from D upon completion + self._reqs_need_send[request.request_id] = time.perf_counter( + ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT return delay_free_blocks, dict( do_remote_prefill=True, do_remote_decode=False, @@ -341,6 +352,7 @@ class LLMDataDistCMgrConnectorWorker(): os.environ["HCCL_DETERMINISTIC"] = "true" self.done_receiving_counts: defaultdict[str, set[int]] = defaultdict(set) + self.reqs_to_send: dict[str, float] = {} def listen_for_agent_metadata_req(self, event: threading.Event): assert self.local_agent_metadata is not None @@ -375,16 +387,13 @@ class LLMDataDistCMgrConnectorWorker(): ) elif event_msg == LLMDataDistCMgrEvent.ReqForFinished: finished_req_id = decode_msg[0] - decode_tp_rank = decode_msg[1] - decode_tp_size = decode_msg[2] with self.thread_lock: - if self._increment_task_count(finished_req_id, - decode_tp_rank, - decode_tp_size): - logger.debug( - f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished" - ) + logger.debug( + f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished" + ) + if finished_req_id in self.reqs_to_send: self.finished_reqs.add(finished_req_id) + del self.reqs_to_send[finished_req_id] sock.send_multipart( (identity, b"", b"receiving decode finished")) else: @@ -392,24 +401,6 @@ class LLMDataDistCMgrConnectorWorker(): f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !" ) - def _increment_task_count(self, request_id: str, tp_rank: int, - decode_tp_size: int): - if request_id not in self.done_receiving_counts: - self.done_receiving_counts[request_id] = set() - if tp_rank in self.done_receiving_counts[request_id]: - logger.warning( - f"Received duplicate done signal for request {request_id} " - f"from tp rank {tp_rank}. 
Ignoring.") - return False - self.done_receiving_counts[request_id].add(tp_rank) - if len(self.done_receiving_counts[request_id]) == decode_tp_size: - self.done_receiving_counts.pop(request_id) - logger.info("All transfers completed for request: " - f"{request_id}. Total ranks: " - f"{decode_tp_size}.") - return True - return False - def init_llm_datadist(self): assert self.local_agent_metadata is not None llm_config = LLMConfig() @@ -502,8 +493,11 @@ class LLMDataDistCMgrConnectorWorker(): assert self.local_agent_metadata is not None kv_cache_dtype = first_kv_cache.dtype self.use_mla: bool = first_kv_cache_tuple[0].size( - -1) != first_kv_cache_tuple[1].size(-1) + -1) != first_kv_cache_tuple[1].size(-1) and len( + first_kv_cache_tuple) == 2 + self.use_sfa: bool = len(first_kv_cache_tuple) == 3 # MLA case. [2 (k_normed, k_pe), num_blocks, ...] + # SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...] # MHA case. [2 (k and v), num_blocks, ...] self.num_blocks = first_kv_cache.shape[0] block_rank = 3 # [block_size, latent_dim] @@ -549,6 +543,58 @@ class LLMDataDistCMgrConnectorWorker(): raise RuntimeError( f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" ) + elif self.use_sfa: + cache_k_normed_addr_list = [] + cache_k_pe_addr_list = [] + cache_k_idx_addr_list = [] + k_normed = None + k_pe = None + k_idx = None + for cache_or_caches in kv_caches.values(): + assert len(cache_or_caches) > 1 + k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[ + 1], cache_or_caches[2] + cache_k_normed_addr_list.append(k_normed.data_ptr()) + cache_k_pe_addr_list.append(k_pe.data_ptr()) + cache_k_idx_addr_list.append(k_idx.data_ptr()) + self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list, + cache_k_idx_addr_list) + + cache_desc_k_normed = CacheDesc( + len(self.cache_addr[0]), [*k_normed.shape], + 
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_desc_k_pe = CacheDesc( + len(self.cache_addr[1]), [*k_pe.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_desc_k_idx = CacheDesc( + len(self.cache_addr[2]), [*k_idx.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_key_k_normed = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=0) + cache_key_k_pe = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=1) + cache_key_k_idx = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=2) + self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe, + cache_desc_k_idx) + self.cache_key = (cache_key_k_normed, cache_key_k_pe, + cache_key_k_idx) + try: + cache_k_normed = self.cache_manager.register_blocks_cache( + self.cache_desc[0], self.cache_addr[0], self.cache_key[0]) + cache_k_pe = self.cache_manager.register_blocks_cache( + self.cache_desc[1], self.cache_addr[1], self.cache_key[1]) + cache_k_idx = self.cache_manager.register_blocks_cache( + self.cache_desc[2], self.cache_addr[2], self.cache_key[2]) + self.cache = (cache_k_normed, cache_k_pe, cache_k_idx) + logger.info("LLMDataDistWorker: End of register Paged Cache.") + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) else: for cache_or_caches in kv_caches.values(): for cache in cache_or_caches: @@ -605,6 +651,7 @@ class LLMDataDistCMgrConnectorWorker(): for future in futures: future.add_done_callback(handle_exception) + self.reqs_to_send.update(metadata.reqs_to_send) def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int: assert self.local_agent_metadata is not None @@ -767,24 +814,24 @@ class LLMDataDistCMgrConnectorWorker(): cluster_id = self.add_remote_agent(metadata) return cluster_id - def 
send_finish_to_remote(self, host: str, port: int, request_id): - url = f"tcp://{host}:{port}" - logger.debug(f"Sending finished to remote: {url}") - msg_encoder = msgspec.msgpack.Encoder() - msg_send = msg_encoder.encode([ - LLMDataDistCMgrEvent.ReqForFinished, - [request_id, self.tp_rank, self.tp_size] - ]) - with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] - try: - sock.send(msg_send) - logger.debug( - f"Request id {request_id} finished message send to remote {url}" - ) - _ = sock.recv() - except Exception as e: - logger.error( - f"Failed to send reqest_id {request_id} to prefill: {e}") + def send_finish_to_remote(self, host: str, ports: list[int], request_id): + for port in ports: + url = f"tcp://{host}:{port}" + logger.debug(f"Sending finished to remote: {url}") + msg_encoder = msgspec.msgpack.Encoder() + msg_send = msg_encoder.encode( + [LLMDataDistCMgrEvent.ReqForFinished, [request_id]]) + with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] + try: + sock.send(msg_send) + logger.debug( + f"Request id {request_id} finished message send to remote {url}" + ) + _ = sock.recv() + except Exception as e: + logger.error( + f"Failed to send reqest_id {request_id} to prefill: {e}" + ) def _read_blocks( self, @@ -834,6 +881,38 @@ class LLMDataDistCMgrConnectorWorker(): raise RuntimeError( "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" ) + elif self.use_sfa: + remote_cache_key_k_normed = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=0) + remote_cache_key_k_pe = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=1) + remote_cache_key_k_idx = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=2) + logger.info("Try pull blocks from remote server") + try: + self.cache_manager.pull_blocks( + remote_cache_key_k_normed, + self.cache[0], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + 
self.cache_manager.pull_blocks( + remote_cache_key_k_pe, + self.cache[1], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + self.cache_manager.pull_blocks( + remote_cache_key_k_idx, + self.cache[2], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type] + ) + except LLMException: + raise RuntimeError( + "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" + ) else: remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id) logger.info("Try pull blocks from remote server") @@ -851,7 +930,10 @@ class LLMDataDistCMgrConnectorWorker(): raise RuntimeError( "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" ) - self.send_finish_to_remote(remote_ip, remote_port, request_id) + remote_ports = list( + range(remote_port + self.tp_rank, + remote_port + int(remote_tp_size), self.tp_size)) + self.send_finish_to_remote(remote_ip, remote_ports, request_id) with self.thread_lock: self.finished_reqs.add(request_id) @@ -859,8 +941,19 @@ class LLMDataDistCMgrConnectorWorker(): self, finished_req_ids: set[str] ) -> tuple[Optional[set[str]], Optional[set[str]]]: """Get the finished recving and sending requuests.""" - import copy + now = time.perf_counter() with self.thread_lock: + while self.reqs_to_send: + req_id, expires = next(iter(self.reqs_to_send.items())) + if now < expires: + break + logger.warning( + "Some requests in prefill node fail to receive KV Cache transfer done signal. 
" + "If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. " + ) + if req_id in self.reqs_to_send: + self.finished_reqs.add(req_id) + del self.reqs_to_send[req_id] req_ids_to_ret = copy.deepcopy(self.finished_reqs) self.finished_reqs.clear() if self.llm_datadist_role == LLMRole.PROMPT: @@ -891,4 +984,4 @@ def zmq_ctx(socket_type: Any, yield socket finally: if ctx is not None: - ctx.destroy(linger=0) + ctx.destroy(linger=0) \ No newline at end of file diff --git a/vllm_ascend/distributed/moe_comm_method.py b/vllm_ascend/distributed/moe_comm_method.py deleted file mode 100644 index aa9bae8..0000000 --- a/vllm_ascend/distributed/moe_comm_method.py +++ /dev/null @@ -1,556 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch_npu -from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fused_moe import FusedMoEConfig - -from vllm_ascend.distributed.communication_op import \ - data_parallel_reduce_scatter -from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version - - -class MoECommMethod(ABC): - """Base class for MoE communication methods.""" - - def __init__(self, moe_config: FusedMoEConfig): - self.moe_config = moe_config - - @abstractmethod - def prepare( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Prepare the MoE communication method. - - This method is called before quant_method.apply to prepare the - communication method. It can be used to initialize any necessary - resources or configurations. 
- """ - pass - - @abstractmethod - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """Finalize the MoE communication method. - - This method is called after quant_method.apply to finalize the - communication method. It can be used to clean up any resources or - configurations. - """ - pass - - @abstractmethod - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - """Pre-process before MLP. - - Args: - hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size) - topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num) - topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num) - expert_map (torch.Tensor): Tensor of shape (global_num_experts, ) - Mapping from global expert IDs to local expert IDs. - num_experts (int): Number of local experts (experts on this device). - apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8). - - Returns: - tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing: - - permuted_hidden_states (torch.Tensor): Tensor of shape - (num_tokens * top_k_num, hidden_size) after permuting - hidden_states based on topk_ids. - - expert_tokens (torch.Tensor): Tensor of shape (num_experts, ) - Number of tokens assigned to each expert. - - dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, ) - Dynamic scale for each expert, used for quantization. - - group_list_type (int): Type of group list, 0 for `cumsum` - and 1 for `count`. This is mainly for `npu_grouped_matmul` - to determine how to handle the output. - Raises: - NotImplementedError: If the method is not implemented in the subclass. 
- """ - pass - - @abstractmethod - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - """Post-process after MLP. - - Args: - mlp_output (torch.Tensor): Tensor of shape - (num_tokens * top_k_num, hidden_size) after MLP. - hidden_states (torch.Tensor): Tensor of shape - (num_tokens, hidden_size) to be updated with the final output. - """ - pass - - -class AllGatherCommImpl(MoECommMethod): - """This implementation is the same as NativeAllGatherCommImpl, - but uses NPU-specific ops for better performance. - - This implementation should be compatible with all scenarios, and - thus it is the default implementation for MoE communication methods. - It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing - and `torch_npu.npu_moe_token_unpermute` for post-processing - to handle the token-to-expert mapping and communication efficiently. - - NOTE(Yizhou): TBH, it is really weird that we were supposed to use - `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing` - or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute` - for pre-processing and post-processing, respectively. - But `npu_moe_finalize_routing` will lead to accuracy issues so we have to - use `torch_npu.npu_moe_token_unpermute` instead. - This is a workaround and should be removed after the issue is fixed. 
- """ - - def prepare( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """When DP size > 1, pad the hidden states and router logits for communication.""" - if self.moe_config.dp_size > 1: - forward_context = get_forward_context() - max_tokens_across_dp = forward_context.max_tokens_across_dp - - self.num_tokens = hidden_states.shape[0] - pad_size = max_tokens_across_dp - self.num_tokens - if pad_size > 0: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) - - hidden_states = self.moe_config.dp_group.all_gather( - hidden_states, 0) - router_logits = self.moe_config.dp_group.all_gather( - router_logits, 0) - - return hidden_states, router_logits - - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """When DP size > 1, reduce-scatter the hidden states to get the final output. - - When TP size > 1, all-reduce the hidden states to get the final output. - """ - if self.moe_config.dp_size > 1: - hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0) - hidden_states = hidden_states[:self.num_tokens] - - if reduce_results and (self.moe_config.tp_size > 1 - or self.moe_config.ep_size > 1): - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - - return hidden_states - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, # noqa: F841 - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - num_tokens = hidden_states.shape[0] - - self.topk_weights = topk_weights - self.topk_ids = topk_ids - - first_expert_idx = 0 - if expert_map is not None: - # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] - # So we need to filter out invalid tokens by zeroing their weights. 
- # This is a workaround and should be removed after the issue is fixed - mask = expert_map[topk_ids] != -1 - # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0, - # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph - self.topk_weights = torch.where(mask, topk_weights, 0.0) - - first_expert_idx = self.moe_config.ep_rank * num_experts - last_expert_idx = first_expert_idx + num_experts - - permuted_hidden_states, expanded_row_idx, expert_tokens, _ = ( - torch_npu.npu_moe_init_routing_v2( - hidden_states, - topk_ids, - active_num=num_tokens * self.moe_config.experts_per_token, - expert_num=self.moe_config.num_experts, - expert_tokens_num_type=1, # Only support `count` mode now - expert_tokens_num_flag=True, # Output `expert_tokens` - active_expert_range=[first_expert_idx, last_expert_idx], - quant_mode=-1, - )) - self.expanded_row_idx = expanded_row_idx - permuted_hidden_states = permuted_hidden_states - - group_list_type = 1 # `count` mode - - return permuted_hidden_states, expert_tokens, None, group_list_type - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - hidden_states[:] = torch_npu.npu_moe_token_unpermute( - permuted_tokens=mlp_output, - sorted_indices=self.expanded_row_idx, - probs=self.topk_weights) - - -class NativeAllGatherCommImpl(AllGatherCommImpl): - """This implementation should be compatible with all scenarios. - - Note that this implementation purely consists of native PyTorch ops - and does not use any NPU-specific ops. So the performance may not be optimal. - But it is a good fallback for scenarios where NPU-specific ops are not available. 
- """ - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - num_tokens = hidden_states.shape[0] - - # Generate token indices and flatten - token_indices = torch.arange(num_tokens, - device=hidden_states.device, - dtype=torch.int64) - token_indices = (token_indices.unsqueeze(1).expand( - -1, self.moe_config.experts_per_token).reshape(-1)) - - # Flatten token-to-expert mappings and map to local experts - weights_flat = topk_weights.view(-1) - experts_flat = topk_ids.view(-1) - local_experts_flat = (expert_map[experts_flat] - if expert_map is not None else experts_flat) - - # Filter valid token-expert pairs - mask = local_experts_flat != -1 - # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] - # So we need to filter out invalid tokens by zeroing their weights. 
- # This is a workaround and should be removed after the issue is fixed - filtered_weights = torch.where(mask, weights_flat, - torch.zeros_like(weights_flat)).to( - topk_weights.dtype) - filtered_experts = torch.where( - mask, - local_experts_flat, - torch.full_like(local_experts_flat, num_experts), - ).to(topk_ids.dtype) - - # Sort by local expert IDs - sort_indices = torch.argsort(filtered_experts.view(torch.float32)) - self.sorted_token_indices = token_indices[sort_indices] - self.sorted_weights = filtered_weights[sort_indices] - - # Compute token counts with minlength of num_experts - # This is equivalent to but faster than: - # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(num_experts + 1, - device=hidden_states.device, - dtype=torch.int64) - ones = torch.ones_like(filtered_experts, dtype=torch.int64) - token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) - expert_tokens = token_counts[:num_experts] - - # Rearrange hidden_states - permuted_hidden_states = hidden_states[self.sorted_token_indices] - - group_list_type = 1 # `count` mode - - return permuted_hidden_states, expert_tokens, None, group_list_type - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - mlp_output = mlp_output * self.sorted_weights.unsqueeze(1) - - final_hidden_states = torch.zeros_like(hidden_states) - final_hidden_states.index_add_(0, self.sorted_token_indices, - mlp_output) - - hidden_states[:] = final_hidden_states - - -class MC2CommImpl(MoECommMethod): - """This implementation is for the scenarios listed below: - 1. `enable_expert_parallel=True`. - 2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available. - 3. `enable_expert_parallel=False` is not supported. - - This implementation uses the MC2 communication method, which is optimized for - Communication and Computation parallelism on Ascend devices. 
- """ - - def __init__(self, moe_config: Optional[FusedMoEConfig]): - super().__init__(moe_config) - - # NOTE: We do not need to use mc2_group's rank and world size - # because ep_group and mc2_group basically have the same init params. - # We only init another group because of the restriction of MC2: - # "No other groups can be used in the same process as the MC2 group." - self.mc2_comm_name = get_mc2_group().device_group._get_backend( - torch.device("npu")).get_hccl_comm_name(self.moe_config.ep_rank) - - # Feature flags - self.enable_dispatch_v2 = hasattr(torch_npu, - "npu_moe_distribute_dispatch_v2") - self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3 - self.need_extra_args = self.is_ascend_a3 - self._restore_tp_across_dp() - - def _restore_tp_across_dp(self): - # NOTE: Since vLLM flatten tp across dp, we need to restore the original - # tp_size and tp_rank. - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - def prepare( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """The target_pad_length is calculated in forward_context, here we pad the - hidden states and router logits. And if TP size > 1, we also need to split - the tensors accordingly. 
- """ - self.num_tokens, _ = hidden_states.shape - forward_context = get_forward_context() - self.mc2_mask = forward_context.mc2_mask - target_pad_length = forward_context.padded_num_tokens - pad_size = target_pad_length - self.num_tokens - - if pad_size > 0: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) - - if self.tp_size > 1: - split_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - split_router_logits = torch.tensor_split(router_logits, - self.tp_size, - dim=0) - split_mc2_mask = torch.tensor_split(self.mc2_mask, - self.tp_size, - dim=0) - self.split_hidden_states = split_hidden_states - - hidden_states = split_hidden_states[self.tp_rank] - router_logits = split_router_logits[self.tp_rank] - self.mc2_mask = split_mc2_mask[self.tp_rank] - - return hidden_states, router_logits - - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """If TP size > 1, all-gather the hidden states to get the final output. - - Also, unpad the hidden states if needed. 
- """ - if self.tp_size > 1: - dist.all_gather(list(self.split_hidden_states), hidden_states, - self.moe_config.tp_group.device_group) - hidden_states = torch.cat(self.split_hidden_states, dim=0) - - if self.num_tokens < hidden_states.shape[0]: - hidden_states = hidden_states[:self.num_tokens] - - return hidden_states - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - # Store tensors needed for post_process - self.topk_ids = topk_ids - self.topk_weights = topk_weights.to(torch.float32) - - dispatch_kwargs = { - "x": hidden_states, - "expert_ids": self.topk_ids, - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": self.moe_config.num_experts, - "global_bs": 0, - "scales": None, - "quant_mode": 2 if apply_a8_quantization else 0, - "group_ep": self.mc2_comm_name, - "ep_world_size": self.moe_config.ep_size, - "ep_rank_id": self.moe_config.ep_rank, - } - - if self.need_extra_args: - dispatch_kwargs.update({ - "group_tp": self.mc2_comm_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if self.is_ascend_a3 and self.enable_dispatch_v2: - dispatch_kwargs.update({ - "x_active_mask": self.mc2_mask, - }) - - dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch - - ( - permuted_hidden_states, - dynamic_scale, - self.assist_info_for_combine, - expert_tokens, - self.ep_recv_counts, - self.tp_recv_counts, - ) = dispatch(**dispatch_kwargs)[:6] - - group_list_type = 1 - - return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - combine_kwargs = { - "expand_x": mlp_output, - "expert_ids": self.topk_ids, - "expert_scales": self.topk_weights, - "expert_shard_type": 0, - 
"shared_expert_rank_num": 0, - "moe_expert_num": self.moe_config.num_experts, - "global_bs": 0, - "ep_send_counts": self.ep_recv_counts, - "group_ep": self.mc2_comm_name, - "ep_world_size": self.moe_config.ep_size, - "ep_rank_id": self.moe_config.ep_rank, - } - - if self.enable_dispatch_v2: - combine_kwargs[ - "assist_info_for_combine"] = self.assist_info_for_combine - else: - combine_kwargs["expand_idx"] = self.assist_info_for_combine - - if self.need_extra_args: - combine_kwargs.update({ - "tp_send_counts": self.tp_recv_counts, - "group_tp": self.mc2_comm_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if self.is_ascend_a3 and self.enable_dispatch_v2: - combine_kwargs.update({ - "x_active_mask": self.mc2_mask, - }) - - combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine - - hidden_states[:] = combine(**combine_kwargs) - - -class AlltoAllCommImpl(MoECommMethod): - """This implementation is for the scenarios listed below: - 1. `enable_expert_parallel=True`. - 2. `npu_grouped_matmul` is available. - - This implementation uses all-to-all communication to exchange tokens - between data parallel ranks before and after the MLP computation. It should - have better performance than AllGatherCommImpl when DP size > 1. - """ - - def __init__(self, moe_config: Optional[FusedMoEConfig]): - super().__init__(moe_config) - from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ - get_token_dispatcher - self.token_dispatcher = get_token_dispatcher( - "TokenDispatcherWithAll2AllV") - self._restore_tp_across_dp() - - def _restore_tp_across_dp(self): - # NOTE: Since vLLM flatten tp across dp, we need to restore the original - # tp_size and tp_rank. 
- self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - def prepare( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - self.num_tokens, _ = hidden_states.shape - pad_size = self.tp_size - self.num_tokens - - if pad_size > 0: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) - - if self.tp_size > 1: - split_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - split_router_logits = torch.tensor_split(router_logits, - self.tp_size, - dim=0) - self.split_hidden_states = split_hidden_states - - hidden_states = split_hidden_states[self.tp_rank] - router_logits = split_router_logits[self.tp_rank] - - return hidden_states, router_logits - - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """If TP size > 1, all-gather the hidden states to get the final output. - - Also, unpad the hidden states if needed. 
- """ - if self.tp_size > 1: - dist.all_gather(list(self.split_hidden_states), hidden_states, - self.moe_config.tp_group.device_group) - hidden_states = torch.cat(self.split_hidden_states, dim=0) - - if self.num_tokens < hidden_states.shape[0]: - hidden_states = hidden_states[:self.num_tokens] - - return hidden_states - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - results = self.token_dispatcher.token_dispatch( - hidden_states, - topk_weights, - topk_ids, - None, - log2phy=None, - with_quant=apply_a8_quantization) - return results["hidden_states"], results["group_list"], results[ - "dynamic_scale"], results["group_list_type"] - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - hidden_states[:] = self.token_dispatcher.token_combine(mlp_output) diff --git a/vllm_ascend/ops/moe_dispatcher/__init__.py b/vllm_ascend/distributed/mooncake/__init__.py similarity index 100% rename from vllm_ascend/ops/moe_dispatcher/__init__.py rename to vllm_ascend/distributed/mooncake/__init__.py diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py new file mode 100644 index 0000000..abb3c9e --- /dev/null +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -0,0 +1,447 @@ +import array +import hashlib +import json +import os +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata +from vllm.utils import cdiv, logger +from vllm.v1.core.sched.output import NewRequestData + + +@dataclass +class MooncakeEngineMetadata: + """name of the LLM model""" + + model_name: str + """ world size when running under a distributed setting """ + 
world_size: int + """ worker id when running under a distributed setting """ + worker_id: int + """ the format of kv tensors """ + kv_dtype: torch.dtype + """ the shape of kv tensors """ + """ (num_layer, 2, metadata.block_size, num_kv_head, head_size) """ + kv_shape: tuple[int, int, int, int, int] + block_size: int = 128 + """ whether use MLA""" + use_mla: bool = False + + +@dataclass(order=True) +class MooncakeEngineKey: + model_name: str + world_size: int + worker_id: int + chunk_hash: str + + def __hash__(self): + return hash(( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + )) + + def to_string(self): + return (f"{self.model_name}@{self.world_size}" + f"@{self.worker_id}@{self.chunk_hash}") + + def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: + """Split the key into multiple keys for each layer""" + keys = [] + for layer_id in range(num_layers): + keys.append( + LayerMooncakeEngineKey( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + layer_id, + )) + return keys + + def to_dict(self): + # Note(Kuntai): this is used for serializing CacheEngineKey via msgpack. 
+ return { + "__type__": "CacheEngineKey", + "model_name": self.model_name, + "world_size": self.world_size, + "worker_id": self.worker_id, + "chunk_hash": self.chunk_hash, + } + + @staticmethod + def from_dict(d): + return MooncakeEngineKey( + model_name=d["model_name"], + world_size=d["world_size"], + worker_id=d["worker_id"], + chunk_hash=d["chunk_hash"], + ) + + +@dataclass(order=True) +class LayerMooncakeEngineKey(MooncakeEngineKey): + """A key for the layer cache engine""" + + layer_id: int + + def __hash__(self): + return hash(( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + self.layer_id, + )) + + def to_string(self): + return (f"{self.model_name}@{self.world_size}" + f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}") + + +class ChunkedTokenDatabase(): + + def __init__( + self, + metadata: MooncakeEngineMetadata, + ): + self.metadata = metadata + + def _make_key_by_hash(self, + chunk_hash: str, + layer_id: Optional[int] = None): + assert self.metadata is not None + return MooncakeEngineKey( + self.metadata.model_name, + self.metadata.world_size, + self.metadata.worker_id, + chunk_hash, + ) + + def _hash( + self, + tokens: Union[torch.Tensor, List[int]], + prefix_hash: str, + ) -> str: + # TODO: change it to a more efficient hash function + if isinstance(tokens, torch.Tensor): + tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes() + elif isinstance(tokens, list): + tokens_bytes = array.array("I", tokens).tobytes() + return hashlib.sha256(prefix_hash.encode("ascii") + + tokens_bytes).hexdigest() + + def _chunk_tokens( + self, + tokens: Union[torch.Tensor, List[int]], + ) -> Iterable[Union[torch.Tensor, List[int]]]: + """ + Chunk the tokens into chunks of size self.metadata.block_size. 
+ + :param tokens: the input tokens, with shape [seq_len] + device: the target device after chunking + + :return: a generator of chunks of tokens, each with + shape [metadata.block_size] + """ + for i in range(0, len(tokens), self.metadata.block_size): + yield tokens[i:i + self.metadata.block_size] + + def _prefix_hash( + self, + token_chunks: Iterable[Union[torch.Tensor, List[int]]], + ) -> Iterable[str]: + prefix_hash = '' + for token_chunk in token_chunks: + prefix_hash = self._hash(token_chunk, prefix_hash) + yield prefix_hash + + def process_tokens( + self, + tokens: Union[torch.Tensor, List[int]], + mask: Optional[torch.Tensor] = None, + ) -> Iterable[Tuple[int, int, MooncakeEngineKey]]: + """Process the tokens and return the corresponding cache engine keys. + + :param Union[torch.Tensor, List[int]] tokens: The tokens to process. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched, + and the Falses will ALWAYS be at the PREFIX of the tensor. + + :param bool make_key: Whether to make the cache engine key or not. + If False, the hash value will be returned instead. + + :returns: A iterable of tuples with three elements. The first element + is the start index of the tokens for the key. The second element + is the end index of the tokens for the key. The third element is + the cache engine key (or hash) for the tokens. + + :raises: ValueError if the number of Falses in the mask is not a + multiple of the chunk size. + """ + if mask is not None: + num_falses = mask.numel() - mask.long().sum().item() + else: + num_falses = 0 + + if num_falses % self.metadata.block_size != 0: + raise ValueError( + "The number of Falses in the mask is not a multiple of the chunk size." 
+ ) + total_len = len(tokens) + + token_chunks = self._chunk_tokens(tokens) + prefix_hashes = self._prefix_hash(token_chunks) + + start_idx = 0 + for chunk_id, hash_val in enumerate(prefix_hashes): + start_idx = chunk_id * self.metadata.block_size + end_idx = min(start_idx + self.metadata.block_size, total_len) + if start_idx < num_falses: + continue + else: + yield start_idx, end_idx, self._make_key_by_hash(hash_val) + + +@dataclass +class LoadSpec: + # Number of tokens cached in vLLM + vllm_cached_tokens: int + # Number of tokens that are cached in mooncake + mooncake_cached_tokens: int + # Whether the scheduler allow us to load the tokens + can_load: bool + + +@dataclass +class SaveSpec: + # Skip already saved tokens + skip_leading_tokens: int + # Whether the scheduler allow us to save the tokens + can_save: bool + + +@dataclass +class RequestTracker: + # Request id + req_id: str + + # The token ids that has been scheduled so far + token_ids: list[int] + + # The block ids that has been allocated so far + # NOTE: allocated blocks could be more than the number of tokens + # FIXME: need to check whether the block ids will be changed after + # preemption + allocated_block_ids: list[int] + + # The number of tokens that has been savd + num_saved_tokens: int = 0 + + @staticmethod + def from_new_request( + new_request: "NewRequestData", + num_tokens_to_compute: int, + ) -> "RequestTracker": + """Create the request tracker from a new request. + + Args: + new_request (NewRequestData): the new request data. + num_tokens_to_compute (int): the number of tokens that will + be 'computed', including the `num_computed_tokens` (vLLM's + local cache hit) and new tokens that will be scheduled. 
+ + """ + # vLLM 0.9.0 update: request.block_ids changed from list[int] to + # list[list[int]] + # Need to check the type of request.block_ids + + unfolded_block_ids = [] + + if not isinstance(new_request.block_ids[0], list): + unfolded_block_ids = new_request.block_ids.copy() + else: + unfolded_block_ids = new_request.block_ids[0].copy() + + return RequestTracker( + req_id=new_request.req_id, + token_ids=new_request.prompt_token_ids[:num_tokens_to_compute]. + copy(), + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=0, + ) + + def update( + self, + new_token_ids: list[int], + new_block_ids: Union[tuple[list[int], ...], list[int]], + ) -> None: + """Update the request tracker when a running request is + scheduled again + """ + + self.token_ids.extend(new_token_ids) + + if len(new_block_ids) == 0: + new_block_ids = [] + elif isinstance(new_block_ids, tuple): + new_block_ids = new_block_ids[0] + elif isinstance(new_block_ids, list): + pass + else: + raise ValueError( + f"Unsupported new_block_ids type {type(new_block_ids)}") + self.allocated_block_ids.extend(new_block_ids) + + +@dataclass +class ReqMeta: + # Request id + req_id: str + # Request tokens + token_ids: torch.Tensor + + block_ids: list[int] + # # Slot mapping if exchange for block_id + # slot_mapping: torch.Tensor + # Skip save or not + save_spec: Optional[SaveSpec] = None + # load_spec + load_spec: Optional[LoadSpec] = None + + is_last_chunk: Optional[bool] = None + + @staticmethod + def from_request_tracker( + tracker: RequestTracker, + block_size: int, + load_spec: Optional[LoadSpec] = None, + skip_save: Optional[bool] = False, + is_last_chunk: Optional[bool] = None, + discard_partial_chunks: bool = True, + ) -> Optional["ReqMeta"]: + """Create the request metadata from a request tracker. + + Args: + tracker (RequestTracker): the request tracker. + block_size (int): the block size in vLLM. + load_spec (Optional[LoadSpec]): the load spec for KV cache loading. 
+ skip_save (bool): whether to skip the save operation. + discard_partial_chunks (bool): whether to discard partial chunks. + + Returns: + the request metadata if we need to perform load/save + operations, None otherwise. + """ + input_token_ids = tracker.token_ids + input_token_len = len(input_token_ids) + + # For save operation: do not save if the following condition is met + # 1. has already been saved before (num_saved_tokens > 0) + # 2. number of unsaved tokens is not reached the chunk boundary + skip_leading_tokens = tracker.num_saved_tokens + chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * + block_size if discard_partial_chunks else 0) + # Calculate number of tokens to save based on discard_partial_chunks + # setting + num_tokens_to_save = ((input_token_len // block_size * block_size) + if discard_partial_chunks else input_token_len) + + skip_save = skip_save or num_tokens_to_save < chunk_boundary + if skip_save and load_spec is None: + return None + + # If we need to save, update the number of saved tokens + if not skip_save: + tracker.num_saved_tokens = num_tokens_to_save + save_spec = SaveSpec(skip_leading_tokens, not skip_save) + + # Calculate the token ids and slot mappings for load and save + # OPTIMIZATION: pre-allocate the buffer for token ids and block ids + token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save] + + # # For load operation: check whether the request is scheduled to load + if load_spec is not None and load_spec.can_load: + logger.debug( + "Scheduled to load %d tokens for request %s", + load_spec.mooncake_cached_tokens, + tracker.req_id, + ) + else: + # Do not load if not in `can_load` state + load_spec = None + logger.debug( + f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}" + ) + return ReqMeta( + req_id=tracker.req_id, + token_ids=token_ids, + block_ids=tracker.allocated_block_ids, + save_spec=save_spec, + load_spec=load_spec, + is_last_chunk=is_last_chunk, + ) + + +class 
MooncakeConnectorMetadata(KVConnectorMetadata): + + def __init__(self, unfinished_request_ids): + self.requests = [] + self.unfinished_request_ids = unfinished_request_ids + + def add_request(self, req_meta: ReqMeta) -> None: + """Add a request to the metadata. + + Args: + req_meta (ReqMeta): the request metadata. + """ + self.requests.append(req_meta) + + +@dataclass +class LasyerMultiBlockReqMeta: + req_id: str + keys: List[LayerMooncakeEngineKey] + starts: List[int] + ends: list[int] + block_ids: list[int] + layer_id: int + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + + @staticmethod + def from_file(file_path: str) -> "MooncakeStoreConfig": + with open(file_path) as file: + config = json.load(file) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=config.get("global_segment_size", 3355443200), + local_buffer_size=config.get("local_buffer_size", 1073741824), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address")) + + @staticmethod + def load_from_env() -> "MooncakeStoreConfig": + config_path = os.getenv("MOONCAKE_CONFIG_PATH") + if not config_path: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_path) \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py new file mode 100644 index 0000000..dee5101 --- /dev/null +++ b/vllm_ascend/distributed/mooncake/kv_transfer.py @@ -0,0 +1,251 @@ +import queue +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Optional + +import torch +from vllm.utils import logger + +from 
vllm_ascend.distributed.mooncake.config_data import ( + ChunkedTokenDatabase, LasyerMultiBlockReqMeta) +from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore + + +class KVTransferThread(threading.Thread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, name: str): + super().__init__(daemon=True, name=name) + self.tp_rank = tp_rank + self.tp_size = tp_size + self.m_store = m_store + self.ready_event = ready_event + self.kv_caches_base_addr = local_kv_caches_base_addr + self.block_len = block_len + self.token_database = token_database + self.block_size = block_size + self.done_task_lock = threading.Lock() + # TODO(jianzs): find a better way to detect MLA. + self.use_mla = len(block_len) == 2 + + self.request_queue: queue.Queue[Any] = queue.Queue() + # TODO(jianzs): make this configurable + self.executor = ThreadPoolExecutor(max_workers=32) + self.finished_requests: set[str] = set() + + def prepare_value(self, start: int, end: int, block_ids: list[int]): + addr_list = [] + size_list = [] + block_id = block_ids[start // self.block_size] + for index, base_addr in enumerate(self.kv_caches_base_addr): + block_len = (self.block_len[index % 2] + if self.use_mla else self.block_len[0]) + + addr = base_addr + block_id * block_len + length = int(block_len / self.block_size * (end - start)) + addr_list.append(addr) + size_list.append(length) + return addr_list, size_list, block_id + + def prepare_value_layer(self, start: int, end: int, block_ids: list[int], + layer_id: int): + block_id = block_ids[start // self.block_size] + if self.use_mla: + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[1] + length_k = int(self.block_len[0] / self.block_size * (end - start)) + 
length_v = int(self.block_len[1] / self.block_size * (end - start)) + size_list = [length_k, length_v] + else: + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[0] + length = int(self.block_len[0] / self.block_size * (end - start)) + size_list = [length, length] + addr_list = [addr_k, addr_v] + return addr_list, size_list + + def add_request( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + is_last_chunk: Optional[bool] = None, + ) -> torch.Tensor: + req = ({ + "req_id": req_id, + "tokens": tokens, + "block_ids": block_ids, + "mask": mask, + "is_last_chunk": is_last_chunk, + }) + self.request_queue.put(req) + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. + """ + with self.done_task_lock: + finished_requests = self.finished_requests.copy() + self.finished_requests.clear() + return finished_requests + + def set_finished_request(self, req_id): + with self.done_task_lock: + self.finished_requests.add(req_id) + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + self.ready_event.set() + while True: + try: + request_data = self.request_queue.get() + if request_data is None: + logger.warning("Received a None request!") + self.request_queue.task_done() + continue + self._handle_request(request_data) + except Exception as e: + logger.error(f"Error in KVCacheTransferThread: {e}") + + def _handle_request(self, req_meta: dict[str, Any]): + pass + + +class KVCacheStoreSendingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event): + 
super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheSendingThread") + + def _handle_request(self, req_meta: dict[str, Any]): + tokens = req_meta["tokens"] + mask = req_meta["mask"] + block_ids = req_meta["block_ids"] + req_id = req_meta["req_id"] + is_last_chunk = req_meta["is_last_chunk"] + torch.npu.current_stream().synchronize() + for start, end, key in self.token_database.process_tokens( + tokens, mask): + addr, size, _ = self.prepare_value(start, end, block_ids) + self.m_store.put(key, addr, size) + if is_last_chunk: + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreRecvingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreRecvingThread") + + def _handle_request(self, req_meta: dict[str, Any]): + tokens = req_meta["tokens"] + mask = req_meta["mask"] + block_ids = req_meta["block_ids"] + req_id = req_meta["req_id"] + for start, end, key in self.token_database.process_tokens( + tokens, mask): + addr, size, _ = self.prepare_value(start, end, block_ids) + self.m_store.get(key, addr, size) + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreLayerSendingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, + num_layers: int): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + 
block_size, + ready_event, + name="KVCacheStoreLayerSendingThread") + self.final_layer_id = num_layers - 1 + + def add_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): + torch.npu.current_stream().synchronize() + for index, key in enumerate(req_meta.keys): + addr, size = self.prepare_value_layer(req_meta.starts[index], + req_meta.ends[index], + req_meta.block_ids, + req_meta.layer_id) + self.m_store.put(key, addr, size) + if req_meta.layer_id == self.final_layer_id: + self.set_finished_request(req_meta.req_id) + self.request_queue.task_done() + + +class KVCacheStoreLayerRecvingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, + get_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreLayerRecvingThread") + self.get_event = get_event + + def add_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): + for index, key in enumerate(req_meta.keys): + addr, size = self.prepare_value_layer(req_meta.starts[index], + req_meta.ends[index], + req_meta.block_ids, + req_meta.layer_id) + self.m_store.get(key, addr, size) + self.request_queue.task_done() + self.get_event.set() diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py new file mode 100644 index 0000000..d89dcd7 --- /dev/null +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -0,0 
+1,489 @@ +# Standard +import math +import threading +import time +from typing import Generator, List, Optional, Union + +# Third Party +import torch +from vllm.config import VllmConfig +from vllm.utils import get_kv_cache_torch_dtype, logger + +from vllm_ascend.distributed.mooncake.config_data import ( + ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata, + MooncakeEngineMetadata) +from vllm_ascend.distributed.mooncake.kv_transfer import ( + KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, + KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) +from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore + + +class MooncakeEngine: + #The main class for the cache engine. + + def __init__( + self, + vllm_config: VllmConfig, + use_layerwize: bool, + ): + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.use_mla = False + if (hasattr(model_config, "use_mla") + and isinstance(model_config.use_mla, bool) + and model_config.use_mla): + self.use_mla = True + self.use_layerwise = use_layerwize + self.tp_rank = parallel_config.rank + self.tp_size = parallel_config.tensor_parallel_size + self.kv_role = vllm_config.kv_transfer_config.kv_role + self.block_size = vllm_config.cache_config.block_size + self.current_layer = 0 + # self.use_mla = first_kv_cache_tuple[0].size( + # -1) != first_kv_cache_tuple[1].size(-1) + self.num_layers = model_config.get_num_layers(parallel_config) + self.block_size = vllm_config.cache_config.block_size + num_kv_head = model_config.get_num_kv_heads(parallel_config) + head_size = model_config.get_head_size() + kv_dtype = get_kv_cache_torch_dtype( + vllm_config.cache_config.cache_dtype, model_config.dtype) + self.hidden_dim_size = num_kv_head * head_size + if self.use_mla: + kv_shape = (self.num_layers, 1, self.block_size, 1, head_size) + else: + kv_shape = (self.num_layers, 2, self.block_size, num_kv_head, + head_size) + self.metadata = 
MooncakeEngineMetadata( + model_config.model, + parallel_config.world_size, + parallel_config.rank, + kv_dtype, + kv_shape, + self.block_size, + self.use_mla, + ) + + self.token_database = ChunkedTokenDatabase(self.metadata) + + self.m_store = Mooncakestore(parallel_config) + + self.kv_send_thread: Optional[KVTransferThread] = None + self.kv_recv_thread: Optional[KVTransferThread] = None + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + + # TODO(tms): Find a more robust way to detect and handle MLA + if self.use_mla: + # MLA case.[num_block, block_size, 1, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", + self.num_blocks, block_shape_norm, block_shape_pe) + else: + # [num_block, block_size, num_head, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + kv_elem_size = first_kv_cache.element_size() + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + self.block_len = [kv_elem_size * math.prod(block_shape)] + logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + + logger.info("Registering KV_Caches. 
use_mla: %s, shape %s", + self.use_mla, first_kv_cache.shape) + + self.kv_caches = kv_caches + self.m_store.set_kv_caches(kv_caches.values()) + self.kv_caches_base_addr = [] + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + if self.use_mla: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + self.kv_caches_base_addr.append(base_addr) + else: + cache_list = [cache_or_caches + ] if self.use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + self.kv_caches_base_addr.append(base_addr) + + if self.use_layerwise: + self.get_event = threading.Event() + if self.kv_role in ['kv_producer', 'kv_both']: + ready_event_sending = threading.Event() + self.kv_send_thread = KVCacheStoreLayerSendingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, + self.block_len, self.block_size, ready_event_sending, + self.num_layers) + self.kv_send_thread.start() + ready_event = threading.Event() + self.kv_recv_thread = KVCacheStoreLayerRecvingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, self.block_len, + self.block_size, ready_event, self.get_event) + self.kv_recv_thread.start() + ready_event.wait() + else: + if self.kv_role in ['kv_producer', 'kv_both']: + ready_event_sending = threading.Event() + self.kv_send_thread = KVCacheStoreSendingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, + self.block_len, self.block_size, ready_event_sending) + self.kv_send_thread.start() + ready_event = threading.Event() + self.kv_recv_thread = KVCacheStoreRecvingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, self.block_len, + self.block_size, ready_event) + self.kv_recv_thread.start() + ready_event.wait() + + def start_load_kv(self, metadata: MooncakeConnectorMetadata): + self.current_layer = 0 + 
self.layerwise_retrievers = [] + for request in metadata.requests: + load_spec = request.load_spec + if load_spec is None or not load_spec.can_load: #load =0 + continue + tokens = request.token_ids + req_id = request.req_id + if (load_spec.mooncake_cached_tokens % self.block_size + != 0) and (load_spec.mooncake_cached_tokens + == tokens.shape[0] - 1): + tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1] + else: + tokens = tokens[:request.load_spec.mooncake_cached_tokens] + masked_token_count = (request.load_spec.vllm_cached_tokens // + self.block_size * self.block_size) + token_mask = torch.ones_like(tokens, dtype=torch.bool) + token_mask[:masked_token_count] = False + if self.use_layerwise: + layerwise_retriever = self.retrieve_layer( + req_id, + tokens, + request.block_ids, + token_mask, + ) + next(layerwise_retriever) # first layer load + self.layerwise_retrievers.append(layerwise_retriever) + else: + self.kv_recv_thread.add_request( # type: ignore[union-attr] + req_id, + tokens, + request.block_ids, + token_mask, + ) + + def wait_for_layer_load(self) -> None: + """MooncakeConnector does not do layerwise saving.""" + for layerwise_retriever in self.layerwise_retrievers: + ret_token_mask = next(layerwise_retriever) + if self.current_layer == self.num_layers - 1: + assert ret_token_mask is not None + num_retrieved_tokens = ret_token_mask.sum().item() + logger.info(f"Retrieved {num_retrieved_tokens} tokens") + + def save_kv_layer(self, + connector_metadata: MooncakeConnectorMetadata) -> None: + """MooncakeConnector does not save explicitly.""" + if self.current_layer == 0: + self.layerwise_storers = [] + for request in connector_metadata.requests: + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + + token_ids = request.token_ids + req_id = request.req_id + assert isinstance(token_ids, torch.Tensor) + assert token_ids.is_cpu + + # TODO: whether need to remov saveThread + # no lookup, skipmask + 
skip_leading_tokens = max( + self.lookup(token_ids, self.use_layerwise), + save_spec.skip_leading_tokens, + ) + if skip_leading_tokens == len(token_ids): + if request.is_last_chunk: + self.kv_send_thread.set_finished_request( # type: ignore[union-attr] + req_id) + continue # skip this request + + skip_leading_tokens = (skip_leading_tokens // self.block_size * + self.block_size) + + store_mask = torch.ones_like(token_ids, dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + layerwise_storer = self.store_layer( + req_id, + token_ids, + mask=store_mask, + block_ids=request.block_ids, + ) + self.layerwise_storers.append(layerwise_storer) + for layerwise_storer in self.layerwise_storers: + try: + next(layerwise_storer) + except Exception: + raise + self.current_layer = self.current_layer + 1 + + def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): + """MooncakeConnector does not save explicitly.""" + for request in connector_metadata.requests: + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + + token_ids = request.token_ids + req_id = request.req_id + assert isinstance(token_ids, torch.Tensor) + assert token_ids.is_cpu + + skip_leading_tokens = max( + self.lookup(token_ids, self.use_layerwise), + save_spec.skip_leading_tokens, + ) + if skip_leading_tokens == len(token_ids): + if request.is_last_chunk: + self.kv_send_thread.set_finished_request( # type: ignore[union-attr] + req_id) + continue # skip this request + + skip_leading_tokens = (skip_leading_tokens // self.block_size * + self.block_size) + + store_mask = torch.ones_like(token_ids, dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) 
for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + self.kv_send_thread.add_request( # type: ignore[union-attr] + req_id, + token_ids, + request.block_ids, + store_mask, + request.is_last_chunk, + ) + + def retrieve_layer( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + ) -> Generator[Optional[torch.Tensor], None, None]: + """ + Retrieve the KV cache in a layerwise manner. + + :param torch.Tensor tokens: The tokens of the corresponding KV caches. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched. + + :param **kwargs: The additional arguments for the KV transfer which + will be passed into the npu_transfer. + + return: A generator that yields Optional[torch.Tensor]. The tensor will + be the boolean mask indicating which tokens are retrieved and will + only be returned in the last iteration. 
+ """ + + if mask is not None: + num_required_tokens = torch.sum(mask).item() + else: + num_required_tokens = len(tokens) + + ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu") + + starts = [] + ends = [] + keys = [] + first_flag = True + for start, end, key in self.token_database.process_tokens( + tokens, mask): + keys_multi_layer = key.split_layers(self.num_layers) + starts.append(start) + ends.append(end) + keys.append(keys_multi_layer) + ret_mask[start:end] = True + + if keys: + # Transpose the keys into layer major format + keys = [list(row) for row in zip(*keys)] # [num_layer,block_num] + for layer_id, keys_multi_chunk in enumerate(keys): + if not first_flag: + is_finish = self.get_event.wait(timeout=3) #try---cache + if not is_finish: + logger.info("Layerwise get failed") + self.get_event.clear() + req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, + starts, ends, block_ids, + layer_id) + self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg] + req_meta) # type: ignore[union-attr, call-arg, arg-type] + first_flag = False + yield None + else: + # If no cache are found, we still need to yield to avoid + # `StopIteration` + for layer_id in range(self.num_layers): + yield None + + retrieved_tokens = torch.sum(ret_mask) + logger.debug(f"Retrieved {retrieved_tokens} " + f"out of {num_required_tokens} " + f"out of total {len(tokens)} tokens") + + yield ret_mask + + def store_layer( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + ) -> Generator[None, None, None]: + """ + Store the KV cache in a layerwise manner. + + :param torch.Tensor tokens: The tokens of the corresponding KV caches. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched. 
+ + :param **kwargs: The additional arguments for the storage backend which + will be passed into the gpu_connector. + + return: A generator that yields None. In the first iteration, the + generator allocates the memory objects for all layers and moves + the KV cache of the first layer from GPU to CPU. In the next + iterations, it moves the KV cache of layer i from GPU to the memory + objects (on CPU) and puts the memory objects of layer i-1 to the + storage backends. In the last iteration, it puts the memory objects + of the last layer to the storage backends. + """ + + if mask is not None: + num_stored_tokens = torch.sum(mask).item() + else: + num_stored_tokens = len(tokens) + + starts = [] + ends = [] + keys = [] + for start, end, key in self.token_database.process_tokens( + tokens, mask): + keys_multi_layer = key.split_layers(self.num_layers) + starts.append(start) + ends.append(end) + keys.append(keys_multi_layer) #[block_num,layer_num] + + if keys: + keys = [list(row) for row in zip(*keys)] #[layer_num,block_num] + for layer_id, keys_multi_chunk in enumerate(keys): + req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, + starts, ends, block_ids, + layer_id) + self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg] + req_meta) # type: ignore[union-attr, call-arg, arg-type] + yield + else: + for layer_id in range(self.num_layers): + yield + logger.debug( + f"Stored {num_stored_tokens} out of total {len(tokens)} tokens") + + def get_finished(self) -> tuple[set[str], set[str]]: + done_sending = ( + self.kv_send_thread. 
+ get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.kv_role in ['kv_producer', 'kv_both'] else set()) + done_recving = self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr] + ) + + logger.debug( + "Number of completed KV cache send requests: %d, receive " + "requests: %d, tp_rank:%d", len(done_sending), len(done_recving), + self.tp_rank) + return done_sending, done_recving + + def wait_layer_transfer_finish(self): + time.sleep(10) + pass + + def lookup( + self, + tokens: Union[torch.Tensor, List[int]], + use_layerwise: bool, + ) -> int: + """ + Checks the existence of KV cache of the tokens from the cache engine. + + :param tokens: the input tokens, with shape [seq_len] + + :return: An int indicating how many prefix tokens are cached. + """ + end = 0 + + for start, end, key in self.token_database.process_tokens(tokens): + try: + if use_layerwise: + keys = [] + keys_multi_layer = key.split_layers(self.num_layers) + for key in keys_multi_layer: + keys.append(key.to_string()) + # batch is_exists + ress = self.m_store.batch_exists(keys) + res = 1 + for value in ress: + if value != 1: + res = 0 + break + else: + res = self.m_store.exists(key) + if res == 1: + continue + else: + return start + except Exception as e: + logger.warning(f"Remote connection failed in contains: {e}") + return start + + # all tokens where found, return the maximal end + return end + + def close(self) -> None: + """Close the cache engine and free all the resources""" + self.m_store.close() diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py new file mode 100644 index 0000000..2383749 --- /dev/null +++ b/vllm_ascend/distributed/mooncake/mooncake_store.py @@ -0,0 +1,88 @@ +# Standard +import os + +# Third Party +from vllm.config import ParallelConfig +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.utils import logger + +from 
vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey + +from .config_data import MooncakeStoreConfig + +METADATA_BYTES_LEN = 24 +BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790")) + + +class Mooncakestore(): + + def __init__(self, parallel_config: ParallelConfig): + try: + from mooncake.store import MooncakeDistributedStore # type: ignore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + tp_rank = get_tensor_model_parallel_rank() + tp_size = parallel_config.tensor_parallel_size + dp_rank = parallel_config.data_parallel_rank_local + all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None) + if not all_device_ids: + device_ids_list = list( + range(dp_rank * tp_size, (dp_rank + 1) * tp_size)) + else: + device_ids_list = list(map(int, all_device_ids.split(','))) + assert len(device_ids_list) > tp_rank + device_id = device_ids_list[tp_rank] + self.config = MooncakeStoreConfig.load_from_env() + if self.config.protocol == "ascend": + local_hostname = self.config.local_hostname + ":" + str(BASE_PORT + int(device_id)) + \ + ":npu_" + str(device_id) + else: + local_hostname = self.config.local_hostname + self.store = MooncakeDistributedStore() + ret = self.store.setup(local_hostname, self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, self.config.device_name, + self.config.master_server_address) + if ret != 0: + msg = "Initialize mooncake failed." 
+ logger.error(msg) + raise RuntimeError(msg) + + def set_kv_caches(self, kvcache): + self.kvcache = list(kvcache) + + def exists(self, key: MooncakeEngineKey) -> bool: + return self.store.is_exist(key.to_string()) == 1 + + def batch_exists(self, keys: list[str]) -> list[bool]: + return self.store.batch_is_exist(keys) + + def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): + expect_res = sum(size) + key_str = key.to_string() + try: + res = self.store.batch_get_into_ascend(key_str, addr, size) + if res[0] != expect_res: + logger.error(f"Failed to get key: [{key_str}] .") + except Exception: + logger.error(f"Failed to get key: [{key_str}] .") + return res + + def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): + key_str = key.to_string() + try: + ret = self.store.batch_put_from_ascend(key_str, addr, size) + if ret[0] != 0: + logger.error(f"Failed to put key {key_str}.") + except Exception: + logger.error(f"Failed to put key {key_str}.") + + return ret + + def close(self): + self.store.close() + logger.info("Closed the mooncake store connection") \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py new file mode 100644 index 0000000..6254e47 --- /dev/null +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -0,0 +1,484 @@ +import threading +from typing import Any, Optional + +import torch +import vllm.envs as envs +import zmq +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.forward_context import ForwardContext +from vllm.utils import logger, make_zmq_socket +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request 
+from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder + +from vllm_ascend.distributed.mooncake.config_data import ( + LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker) +from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine + + +class MooncakeConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self.kv_role = vllm_config.kv_transfer_config.kv_role + + self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "use_layerwise", False) + + self.kv_caches: dict[str, torch.Tensor] = {} + + self._block_size = vllm_config.cache_config.block_size + + self.sended_but_unfinished_reqs: set[str] = set() + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = MooncakeStoreConnectorV1Scheduler( + vllm_config, self.use_layerwise) + else: + self.connector_worker = MooncakeEngine( + vllm_config, + self.use_layerwise, + ) + + assert self.connector_worker is not None + if vllm_config.parallel_config.rank == 0: + self.lookup_server = MooncakeLookupServer( + self.connector_worker, vllm_config, self.use_layerwise) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert 
self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._get_connector_metadata(), + MooncakeConnectorMetadata) + self.connector_worker.start_load_kv(self._get_connector_metadata()) + + def wait_for_layer_load(self, layer_name: str) -> None: + """MooncakeStoreConnector does not do layerwise saving.""" + if not self.use_layerwise: + return + self.connector_worker.wait_for_layer_load() + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """MooncakeStoreConnector does not save explicitly.""" + if not self.use_layerwise: + return + + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + self.connector_worker.save_kv_layer(self._get_connector_metadata()) + + def wait_for_save(self): + """MooncakeStoreConnector does not save explicitly.""" + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + + if self.use_layerwise: + self.connector_worker.wait_layer_transfer_finish() + return + + self.connector_worker.wait_for_save(self._get_connector_metadata()) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and 
sending requests.""" + assert self.connector_worker is not None + meta = self._get_connector_metadata() + done_sending, done_recving = self.connector_worker.get_finished() + sended_and_finished: set[str] = set() + for item in list(self.sended_but_unfinished_reqs): + if item not in meta.unfinished_request_ids: + sended_and_finished.add(item) + self.sended_but_unfinished_reqs.remove(item) + for item in done_sending: + if item in meta.unfinished_request_ids: + self.sended_but_unfinished_reqs.add(item) + else: + sended_and_finished.add(item) + + return sended_and_finished, done_recving + + +def get_zmq_rpc_path_mooncake( + vllm_config: Optional["VllmConfig"] = None, ) -> str: + base_url = envs.VLLM_RPC_BASE_PATH + # Default to 0 if not configured + rpc_port = 0 + if vllm_config is not None: + rpc_port = vllm_config.kv_transfer_config.get_from_extra_config( + "mooncake_rpc_port", 0) + logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) + return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}" + + +class MooncakeStoreConnectorV1Scheduler: + + def __init__(self, vllm_config: "VllmConfig", use_layerwise): + self.client = MooncakeLookupClient(vllm_config) + self.use_layerwise = use_layerwise + self.kv_role = vllm_config.kv_transfer_config.kv_role + # request_id -> (vllm cached tokes, mooncake cached tokens) + self.load_specs: dict[str, LoadSpec] = {} + self._block_size = vllm_config.cache_config.block_size + # request_id -> full_token_ids + self._request_trackers: dict[str, RequestTracker] = {} + # Whether to discard partial chunks + self._discard_partial_chunks = ( + vllm_config.kv_transfer_config.get_from_extra_config( + "discard_partial_chunks", True)) + self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {} + self._unfinished_request_ids: set[str] = set() + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Check for external KV cache hit. 
+ + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + + if self._discard_partial_chunks: + token_block_end = len(request.prompt_token_ids + ) // self._block_size * self._block_size + token_ids = torch.tensor( + request.prompt_token_ids[:token_block_end]) + else: + token_ids = torch.tensor(request.prompt_token_ids) + + num_external_hit_tokens = self.client.lookup(token_ids) + + if num_external_hit_tokens == request.num_tokens: + num_external_hit_tokens -= 1 + + need_to_allocate = num_external_hit_tokens - num_computed_tokens + + logger.info( + "Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d", + request.request_id, + request.num_tokens, + num_external_hit_tokens, + need_to_allocate, + ) + + if need_to_allocate <= 0: + return 0, False + + self.load_specs[request.request_id] = LoadSpec( + vllm_cached_tokens=num_computed_tokens, + mooncake_cached_tokens=num_external_hit_tokens, + can_load=False, + ) + + return need_to_allocate, not self.use_layerwise + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after temporary buffer alloc. + + For SharedStorageConnector, update _request_needs_load + if the CacheManager this allocated blocks for us. 
+ """ + local_block_ids = [] + if num_external_tokens > 0: + local_block_ids = blocks.get_block_ids()[0] + + self._unfinished_requests[request.request_id] = (request, + local_block_ids) + self._unfinished_request_ids.add(request.request_id) + if request.request_id not in self.load_specs: + # No KV tokens from external KV cache, return + return + + if num_external_tokens == 0: + # No need to load anything + self.load_specs[request.request_id].can_load = False + return + + assert ( + num_external_tokens > 0 and num_external_tokens + == self.load_specs[request.request_id].mooncake_cached_tokens - + self.load_specs[request.request_id].vllm_cached_tokens + ), (f"Mismatch in number of tokens: {num_external_tokens} vs " + f"{self.load_specs[request.request_id].mooncake_cached_tokens} - " + f"{self.load_specs[request.request_id].vllm_cached_tokens}" + f" for request {request.request_id}") + + self.load_specs[request.request_id].can_load = True + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """Attach the connector metadata to the request object. + + This function should NOT modify other fields in the scheduler_output + except the `kv_connector_metadata` field. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + + force_skip_save = self.kv_role == "kv_consumer" + + for finished_req_id in scheduler_output.finished_req_ids: + self._request_trackers.pop(finished_req_id, None) + self._unfinished_requests.pop(finished_req_id, None) + self._unfinished_request_ids.remove(finished_req_id) + + meta = MooncakeConnectorMetadata(self._unfinished_request_ids) + + for request in scheduler_output.scheduled_new_reqs: + # Right now, we only load KV for new requests + load_spec = self.load_specs.pop(request.req_id, None) + num_tokens_to_compute = ( + request.num_computed_tokens + + scheduler_output.num_scheduled_tokens[request.req_id]) + request_tracker = RequestTracker.from_new_request( + request, num_tokens_to_compute) + self._request_trackers[request.req_id] = request_tracker + last_chunk_tokens_num = ((len(request.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else len( + request.prompt_token_ids)) + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=load_spec, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + + cached_reqs = scheduler_output.scheduled_cached_reqs + if isinstance(cached_reqs, list) and not force_skip_save: + for i, req in enumerate(cached_reqs): + request_tracker = self._request_trackers[req.req_id] + request_tracker.update(req.new_token_ids, req.new_block_ids) + last_chunk_tokens_num = ((len(req.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else + len(req.prompt_token_ids)) + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=None, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + 
meta.add_request(req_meta) + elif not force_skip_save: + for i, req_id in enumerate(cached_reqs.req_ids): + request_tracker = self._request_trackers[req_id] + num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + req_tuple = self._unfinished_requests.get(req_id) + if req_tuple: + request = req_tuple[0] + num_current_tokens = len(request_tracker.token_ids) + new_token_ids = request.all_token_ids[ + num_current_tokens:num_current_tokens + num_new_tokens] + else: + raise ValueError( + f"Request {req_id} is not in _unfinished_requests, " + f"but it is scheduled to be cached") + new_block_ids = cached_reqs.new_block_ids[i] + if not new_block_ids: + continue + request_tracker.update(new_token_ids, new_block_ids) + # decode not save + if len(request_tracker.token_ids) > len( + request.prompt_token_ids): + continue + + last_chunk_tokens_num = ((len(request.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else + len(request.prompt_token_ids)) + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=None, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + + request_ids = [ + req.req_id for req in scheduler_output.scheduled_new_reqs + ] + for request_id, (request, + block_ids) in self._unfinished_requests.items(): + if request_id not in request_ids and request_id not in cached_reqs.req_ids: + load_spec = self.load_specs.pop(request_id, None) + if not load_spec: + continue + num_tokens_to_compute = load_spec.mooncake_cached_tokens + if (num_tokens_to_compute % self._block_size + != 0) and (num_tokens_to_compute + == len(request.prompt_token_ids) - 1): + num_tokens_to_compute = num_tokens_to_compute + 1 + request_tracker = RequestTracker( + req_id=request_id, + token_ids=request.prompt_token_ids[:num_tokens_to_compute]. 
+ copy(), + allocated_block_ids=block_ids, + num_saved_tokens=0, + ) + + self._request_trackers[request_id] = request_tracker + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=load_spec, + skip_save=None, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + if self.kv_role == "kv_consumer": + return False, None + if self._request_trackers[request.request_id].num_saved_tokens <= 0: + return False, None + delay_free_blocks = len(block_ids) > 0 + if delay_free_blocks: + logger.info("Delaying free of %d blocks for request %s", + len(block_ids), request.request_id) + return delay_free_blocks, None + + +class MooncakeLookupClient: + + def __init__(self, vllm_config: "VllmConfig"): + self.encoder = MsgpackEncoder() + self.ctx = zmq.Context() # type: ignore[attr-defined] + socket_path = get_zmq_rpc_path_mooncake(vllm_config) + self.socket = make_zmq_socket( + self.ctx, + socket_path, + zmq.REQ, # type: ignore[attr-defined] + bind=False, + ) + + def lookup(self, token_ids: torch.Tensor) -> int: + request = self.encoder.encode(token_ids) + self.socket.send_multipart(request, copy=False) + resp = self.socket.recv() + result = int.from_bytes(resp, "big") + return result + + def close(self): + self.socket.close(linger=0) + + +class MooncakeLookupServer: + + def __init__( + self, + mooncake_engine: MooncakeEngine, + vllm_config: "VllmConfig", + use_layerwise: bool, + ): + self.decoder = MsgpackDecoder(torch.Tensor) + self.ctx = zmq.Context() # type: ignore[attr-defined] + socket_path = get_zmq_rpc_path_mooncake(vllm_config) + self.socket = make_zmq_socket( + self.ctx, + socket_path, + 
zmq.REP, # type: ignore[attr-defined] + bind=True, + ) + + self.mooncake_engine = mooncake_engine + self.running = True + + def process_request(): + while self.running: + frames = self.socket.recv_multipart(copy=False) + token_ids = self.decoder.decode(frames) + result = self.mooncake_engine.lookup(token_ids, use_layerwise) + response = result.to_bytes(4, "big") + self.socket.send(response) + + self.thread = threading.Thread(target=process_request, daemon=True) + self.thread.start() + + def close(self): + self.socket.close(linger=0) + # TODO: close the thread! \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 4faf37d..c0fc1a6 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -11,7 +11,7 @@ from collections import defaultdict, deque from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple import msgspec import numpy as np @@ -19,6 +19,7 @@ import numpy.typing as npt import torch import zmq from mooncake.engine import TransferEngine # type: ignore +from vllm import envs from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) @@ -29,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -67,12 +69,16 @@ class KVCacheTaskTracker: # intentionally delayed. Each entry is a tuple of (request_id, # timestamp). If a request remains in this queue for too long, it will # be force-freed. 
- self.delayed_free_requests: deque[Tuple[str, float]] = deque() + self.record_finished_requests: set[str] = set() + self.delayed_free_requests: OrderedDict[str, float] = OrderedDict() def update_done_task_count(self, request_id: str): with self.done_task_lock: self.finished_requests.add(request_id) - self._remove_delayed_requests(request_id) + if request_id in self.delayed_free_requests: + self._remove_delayed_requests(request_id) + else: + self.record_finished_requests.add(request_id) def get_and_clear_finished_requests(self) -> set[str]: """ @@ -90,7 +96,10 @@ class KVCacheTaskTracker: def add_delayed_request(self, request_id: str, delay_start_time: float): """Add a delayed free request.""" with self.done_task_lock: - self.delayed_free_requests.append((request_id, delay_start_time)) + if request_id not in self.record_finished_requests: + self.delayed_free_requests[request_id] = delay_start_time + else: + self.record_finished_requests.discard(request_id) def _retrieve_expired_requests(self): """Retrieve all expired delayed requests.""" @@ -98,10 +107,11 @@ class KVCacheTaskTracker: # Free delayed requests if they exceed the timeout current_time = time.time() while self.delayed_free_requests: - request_id, delay_start_time = self.delayed_free_requests[0] + request_id = next(iter(self.delayed_free_requests)) + delay_start_time = self.delayed_free_requests[request_id] if (current_time - delay_start_time - > envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT): - self.delayed_free_requests.popleft() + > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT): + self.delayed_free_requests.popitem(last=False) expired_requests.add(request_id) logger.info("Force freed request: %s", request_id) else: @@ -110,8 +120,7 @@ class KVCacheTaskTracker: def _remove_delayed_requests(self, request_id: str): """Remove all delayed free requests matching the given request_id.""" - self.delayed_free_requests = deque( - (r, t) for r, t in self.delayed_free_requests if r != request_id) + 
self.delayed_free_requests.pop(request_id) class KVCacheSendingThread(threading.Thread): @@ -230,6 +239,7 @@ class KVCacheRecvingThread(threading.Thread): self.block_len = block_len # TODO(jianzs): find a better way to detect MLA. self.use_mla = len(block_len) == 2 + self.use_sfa = len(block_len) == 3 self.request_queue: queue.Queue[Any] = queue.Queue() # TODO(jianzs): make this configurable @@ -341,8 +351,12 @@ class KVCacheRecvingThread(threading.Thread): src_list, dst_list, length_list = [], [], [] for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)): - block_len = (self.block_len[k % 2] - if self.use_mla else self.block_len[0]) + if self.use_mla: + block_len = (self.block_len[k % 2]) + elif self.use_sfa: + block_len = (self.block_len[k % 3]) + else: + block_len = (self.block_len[0]) for i, remote_block_id in enumerate(grouped_remote_block_ids): local_block_ids = grouped_local_block_ids[i] src = src_layer_base_addr + local_block_ids[0] * block_len @@ -559,6 +573,7 @@ class MooncakeConnectorScheduler: def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config + self.ascend_config = get_ascend_config() self.block_size = vllm_config.cache_config.block_size self.engine_id = engine_id logger.info("Initializing Mooncake Scheduler %s", engine_id) @@ -718,7 +733,7 @@ class MooncakeConnectorScheduler: assert "tp_size" in decode_parallel_config.keys() self._decode_tp_size = decode_parallel_config["tp_size"] - if self.vllm_config.model_config.use_mla: + if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa: return self._decode_tp_size else: # TODO support mha and gqa @@ -782,10 +797,12 @@ class MooncakeConnectorWorker: assert len(device_ids) > self.tp_rank # type: ignore self.device_id = device_ids[self.tp_rank] # type: ignore - self._initialize( - hostname=self.side_channel_host + ':' + '0' + ':' + 'npu_' \ - + str(self.device_id), - device_name=None) 
+ if vllm_config.kv_transfer_config.get_from_extra_config( + 'use_ascend_direct', False): + hostname = self.side_channel_host + else: + hostname = f"{self.side_channel_host}:0:npu_{self.device_id}" + self._initialize(hostname=hostname, device_name=None) self.te_rpc_port = self.engine.get_rpc_port() # Background thread for sending or receiving KV caches. @@ -837,7 +854,9 @@ class MooncakeConnectorWorker: # TODO(tms): Find a more robust way to detect and handle MLA self.use_mla = first_kv_cache_tuple[0].size( - -1) != first_kv_cache_tuple[1].size(-1) + -1) != first_kv_cache_tuple[1].size(-1) and len( + first_kv_cache_tuple) == 2 + self.use_sfa = len(first_kv_cache_tuple) == 3 if self.use_mla: # MLA case.[num_block, block_size, 1, hidden_dim] self.num_blocks = first_kv_cache.shape[0] @@ -851,6 +870,21 @@ class MooncakeConnectorWorker: logger.info( "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", self.num_blocks, block_shape_norm, block_shape_pe) + elif self.use_sfa: + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe), + first_kv_cache[2].element_size() * math.prod(block_shape_k) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s", + self.num_blocks, block_shape_norm, block_shape_pe, + block_shape_k) else: # [num_block, block_size, num_head, hidden_dim] self.num_blocks = first_kv_cache.shape[0] @@ -861,8 +895,9 @@ class MooncakeConnectorWorker: logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) - logger.info("Registering KV_Caches. 
use_mla: %s, shape %s", - self.use_mla, first_kv_cache.shape) + logger.info( + "Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s", + self.use_mla, self.use_sfa, first_kv_cache.shape) self.kv_caches = kv_caches kv_caches_base_addr = [] @@ -874,9 +909,16 @@ class MooncakeConnectorWorker: region_len = self.num_blocks * self.block_len[i % 2] kv_caches_base_addr.append(base_addr) self._register(base_addr, region_len) + elif self.use_sfa: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % 3] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) else: - cache_list = [cache_or_caches - ] if self.use_mla else cache_or_caches + cache_list = [ + cache_or_caches + ] if self.use_mla or self.use_sfa else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[0] diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index f81d501..07c707e 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -11,7 +11,7 @@ from vllm_ascend.ascend_config import get_ascend_config # Currently, mc2 op need their own group coordinator. _MC2: Optional[GroupCoordinator] = None _MLP_TP: Optional[GroupCoordinator] = None - +_OTP: Optional[GroupCoordinator] = None _LMTP: Optional[GroupCoordinator] = None @@ -20,6 +20,12 @@ def get_mc2_group() -> GroupCoordinator: return _MC2 +def get_otp_group() -> GroupCoordinator: + assert _OTP is not None, ( + "output tensor parallel group is not initialized") + return _OTP + + def get_lmhead_tp_group() -> GroupCoordinator: assert _LMTP is not None, ( "lm head tensor parallel group is not initialized") @@ -74,6 +80,20 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): backend, group_name="mlp_tp") + # If oproj tensor parallel size is set, we will create a group for it. 
+ otp_size = get_ascend_config().oproj_tensor_parallel_size + if otp_size is not None: + group_ranks = [] + global _OTP + num_oproj_tensor_parallel_groups: int = (world_size // otp_size) + for i in range(num_oproj_tensor_parallel_groups): + ranks = list(range(i * otp_size, (i + 1) * otp_size)) + group_ranks.append(ranks) + _OTP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="otp") + lmhead_tensor_parallel_size = get_ascend_config( ).lmhead_tensor_parallel_size if lmhead_tensor_parallel_size is not None: @@ -117,3 +137,8 @@ def destroy_ascend_model_parallel(): if _LMTP: _LMTP.destroy() _LMTP = None + + global _OTP + if _OTP: + _OTP.destroy() + _OTP = None diff --git a/vllm_ascend/distributed/tensor_parallel.py b/vllm_ascend/distributed/tensor_parallel.py deleted file mode 100644 index 3fff0a7..0000000 --- a/vllm_ascend/distributed/tensor_parallel.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapts from: Megatron/megatron/core/tensor_parallel/mappings.py. -# This file is a part of the vllm-ascend project. -import torch - - -def _gather_along_first_dim(input_, group, output_split_sizes=None): - """Gather tensors and concatenate along the first dimension. 
- - Args: - input_tensor (torch.Tensor): - A tensor to be gathered. - output_split_sizes (List[int], optional): - A list specifying the sizes of the output splits along the first dimension. - If None, equal splitting is assumed. Default: None. - - Returns: - torch.Tensor: Gathered tensor. - """ - world_size = torch.distributed.get_world_size(group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - dim_size = list(input_.size()) - if output_split_sizes is None: - dim_size[0] = dim_size[0] * world_size - - output = torch.empty(dim_size, - dtype=input_.dtype, - device=torch.npu.current_device()) - torch.distributed.all_gather_into_tensor(output, - input_.contiguous(), - group=group) - else: - dim_size[0] = sum(output_split_sizes) - output = torch.empty(dim_size, - dtype=input_.dtype, - device=torch.npu.current_device()) - output_tensor_list = list( - torch.split(output, output_split_sizes, dim=0)) - torch.distributed.all_gather(output_tensor_list, input_, group=group) - - return output - - -def _gather_along_last_dim(input_, group): - """Gather tensors and concatenate along the last dimension.""" - - world_size = torch.distributed.get_world_size(group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - dim_size = list(input_.size()) - dim_size[0] = dim_size[0] * world_size - - output = torch.empty(dim_size, - dtype=input_.dtype, - device=torch.npu.current_device()) - torch.distributed.all_gather_into_tensor(output, - input_.contiguous(), - group=group) - tensor_list = output.chunk(world_size, dim=0) - output = torch.cat(tensor_list, dim=-1).contiguous() - - return output - - -def _reduce_scatter_along_first_dim(input_, - group, - input_split_sizes=None, - use_global_buffer=False): - """Reduce-scatter the input tensor across model parallel group. - - Args: - input_ (torch.Tensor): The input tensor to be reduce-scattered. 
- input_split_sizes (List[int], optional): A list specifying the sizes of - the input splits along the first dimension for each rank. If None, - equal splitting is assumed. Default: None. - """ - world_size = torch.distributed.get_world_size(group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - if input_split_sizes is None: - dim_size = list(input_.size()) - assert ( - dim_size[0] % world_size == 0 - ), "First dimension of the tensor should be divisible by tensor parallel size" - - dim_size[0] = dim_size[0] // world_size - - output = torch.empty(dim_size, - dtype=input_.dtype, - device=torch.npu.current_device()) - torch.distributed.reduce_scatter_tensor(output, - input_.contiguous(), - group=group) - else: - rank = torch.distributed.get_rank(group) - input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0)) - - output = torch.empty_like(input_tensor_list[rank]) - torch.distributed.reduce_scatter(output, - input_tensor_list, - group=group) - return output - - -def _reduce_scatter_along_last_dim(input_, group): - """Reduce-scatter tensors on the last dimension.""" - world_size = torch.distributed.get_world_size(group) - target_shape = list(input_.size()) - target_shape[-1] = target_shape[-1] // world_size - input_ = input_.reshape(-1, input_.shape[-1]) - split_tensors = torch.split(input_, - split_size_or_sections=input_.shape[-1] // - world_size, - dim=1) - concat_tensor = torch.cat(split_tensors, dim=0) - output = _reduce_scatter_along_first_dim(concat_tensor, - group).reshape(target_shape) - return output - - -def all_gather_last_dim_from_tensor_parallel_region(input_, group): - """Wrapper for autograd function: forward: AG, backward RS """ - return _gather_along_last_dim(input_, group) - - -def reduce_scatter_to_sequence_parallel_region(input_, - group, - input_split_sizes=None): - """Wrapper for autograd function: forward: RS, backward AG """ - return _reduce_scatter_along_first_dim(input_, group, 
input_split_sizes) - - -def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group): - """Wrapper for autograd function: forward: RS, backward AG: AG """ - return _reduce_scatter_along_last_dim(input_, group) - - -def gather_from_sequence_parallel_region( - input_, - group, - output_split_sizes=None, -): - """Wrapper for autograd function: forward: AG, backward: RS """ - return _gather_along_first_dim(input_, group, output_split_sizes) - - -def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None): - world_size = torch.distributed.get_world_size(group=group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input - - input = input.contiguous() - if output_split_sizes is None: - # Equal split (all2all) - output = torch.empty_like(input) - else: - # Unequal split (all2all-v) - output = input.new_empty( - size=[sum(output_split_sizes)] + list(input.size()[1:]), - dtype=input.dtype, - device=torch.npu.current_device(), - ) - torch.distributed.all_to_all_single( - output, - input, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - ) - return output - - -def all_to_all_sp2hp(input_, group): - """ - Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape - [num_tokens/TP, H] to [num_tokens, H/TP]. - - Args: - input_ (torch.Tensor): - The input tensor which has been distributed along the sequence - dimension. - - Returns: - torch.Tensor: The output tensor with shape [num_tokens, H/TP]. 
- - """ - if group is None: - return input_ - world_size = torch.distributed.get_world_size(group=group) - tp_group = group - input_ = input_.reshape(-1, input_.shape[-1]) - split_tensors = torch.split(input_, - split_size_or_sections=input_.shape[-1] // - world_size, - dim=1) - concat_tensor = torch.cat(split_tensors, dim=0) - output = all_to_all(tp_group, concat_tensor) - return output - - -def all_to_all_hp2sp(input_, group): - """ - Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape - [num_tokens, H/TP] to [num_tokens/TP, H]. - - Args: - input_ (torch.Tensor): - The input tensor which has been distributed along the hidden - dimension. - - Returns: - torch.Tensor: The output tensor with shape [num_tokens/TP, H]. - """ - if group is None: - return input_ - world_size = torch.distributed.get_world_size(group=group) - input_ = input_.reshape(-1, input_.shape[-1]) - tp_group = group - input_exchanged = all_to_all(tp_group, input_) - input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1]) - split_tensors = torch.split( - input_reshaped, - split_size_or_sections=input_reshaped.shape[0] // world_size, - dim=0) - output = torch.cat(split_tensors, dim=-1) - return output diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 78f8c50..2db4515 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -131,6 +131,26 @@ env_variables: Dict[str, Callable[[], Any]] = { # this feature is supported in A2, and eager mode will get better performance. "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))), + # Whether to enable FlashComm optimization when tensor parallel is enabled. + # This feature will get better performance when concurrency is large. + "VLLM_ASCEND_ENABLE_FLASHCOMM": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))), + # Whether to enable MLP weight prefetch, only used in small concurrency. 
+ "VLLM_ASCEND_ENABLE_PREFETCH_MLP": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))), + # buffer size for gate up prefetch + "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": + lambda: int( + os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)), + # buffer size for down proj prefetch + "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": + lambda: int( + os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)), + # Whether to enable dense model and general optimizations for better performance. + # Since we modified the base parent class `linear`, this optimization is also applicable to other model types. + # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models. + "VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))), # Whether to enable mlp optimize when tensor parallel is enabled. # this feature in eager mode will get better performance. "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": @@ -139,11 +159,16 @@ env_variables: Dict[str, Callable[[], Any]] = { # caused by the initialization of the Mooncake connector. "PHYSICAL_DEVICES": lambda: os.getenv("PHYSICAL_DEVICES", None), + # Whether to enable msMonitor tool to monitor the performance of vllm-ascend. + "MSMONITOR_USE_DAEMON": + lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))), # Timeout (in seconds) for delayed KVCache block release. In the prefill # node, if a request is marked for delayed KV block release and the blocks # are not freed within this timeout, they will be forcibly released. 
"VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT": lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)), + "VLLM_ASCEND_ENABLE_MLAPO": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))), } # end-env-vars-definition @@ -157,4 +182,4 @@ def __getattr__(name: str): def __dir__(): - return list(env_variables.keys()) + return list(env_variables.keys()) \ No newline at end of file diff --git a/vllm_ascend/eplb/__init__.py b/vllm_ascend/eplb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/eplb/adaptor/__init__.py b/vllm_ascend/eplb/adaptor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/eplb/adaptor/abstract_adaptor.py b/vllm_ascend/eplb/adaptor/abstract_adaptor.py new file mode 100644 index 0000000..ab37fde --- /dev/null +++ b/vllm_ascend/eplb/adaptor/abstract_adaptor.py @@ -0,0 +1,44 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this adaptor. 
+from abc import abstractmethod +from typing import Any + + +class EplbAdaptor(): + + def __init__(self, **args): + pass + + @abstractmethod + def get_rank_expert_workload(self): + raise NotImplementedError + + @abstractmethod + def get_init_expert_map(self, num_moe_layers: Any) -> Any: + raise NotImplementedError + + @abstractmethod + def do_update_expert_map(self, layer_id: Any, + updated_expert_map: Any) -> Any: + raise NotImplementedError + + @abstractmethod + def do_update_expert_weight(self, layer_id: Any, + local_expert_to_replace: Any, + buffer_tensor_id: Any) -> Any: + raise NotImplementedError diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py new file mode 100644 index 0000000..d5ac509 --- /dev/null +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -0,0 +1,289 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this adaptor. 
+import json +from typing import Any + +import torch +import torch.distributed as dist +from vllm.logger import logger + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor + + +class VllmEplbAdaptor(EplbAdaptor): + + def __init__(self, model, **args): + super().__init__(**args) + self.model = model + self.rank_id = dist.get_rank() + self.world_size = dist.get_world_size() + self.param_dict = dict(self.model.named_parameters()) + if self.model.config.model_type == "qwen3_moe": + self.num_dense_layers = 0 + self.global_expert_num = self.model.config.num_experts + else: + self.num_dense_layers = self.model.config.first_k_dense_replace + self.global_expert_num = self.model.config.n_routed_experts + self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers + self.init_redundancy_expert = get_ascend_config( + ).init_redundancy_expert + + # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here + if self.model.quant_config is not None: + self.expert_weight_names = [ + "w13_weight", "w2_weight", "w13_weight_scale", + "w13_weight_offset", "w2_weight_scale", "w2_weight_offset" + ] + else: + self.expert_weight_names = ["w13_weight", "w2_weight"] + + self.expert_map_per_layer = dict( + ) # reference to expert map on device for expert map update + self.expert_map_per_layer_cpu = dict( + ) # copy of expert map on CPU to avoid device synchronize frequently + for layer_idx in range(self.num_moe_layers): + self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \ + self.model.get_expert_map(self.num_dense_layers + layer_idx) + + # TODO: here we set number of buffer tensor equal to number of expert in each laryer, which can be improved + num_buffer_tensor = torch.where( + self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel() + self.buffer_tensor_list: list[list[Any]] = [ + [] for _ in 
range(num_buffer_tensor) + ] + self.init_buffer_tensor(num_buffer_tensor) + + self.expert_param_per_layer = dict() + self.init_expert_param_per_layer() + + self.log2phy_map_per_layer = dict() + for layer_idx in range(self.num_moe_layers): + self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \ + self.model.get_log2phy_map(self.num_dense_layers + layer_idx) + + self.all_topk_ids = [] + + def init_buffer_tensor(self, num_buffer_tensor): + for name in self.expert_weight_names: + complete_name = "model.layers." + str( + self.num_dense_layers) + ".mlp.experts." + name + expert_tensor = self.param_dict[complete_name].data[ + 0:num_buffer_tensor] + buffer_tensors = torch.empty_like(expert_tensor) + for buffer_id in range(num_buffer_tensor): + self.buffer_tensor_list[buffer_id].append( + buffer_tensors[buffer_id]) + + def init_expert_param_per_layer(self): + num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) + \ + ".mlp.experts." + self.expert_weight_names[0]].data.shape[0] + for moe_layer_id in range(self.num_moe_layers): + layer_idx = self.num_dense_layers + moe_layer_id + self.expert_param_per_layer[layer_idx] = list() + for local_expert_id in range(num_local_expert): + self.expert_param_per_layer[layer_idx].append([ + self.param_dict["model.layers." + str(layer_idx) + + ".mlp.experts." 
+ + name].data[local_expert_id] + for name in self.expert_weight_names + ]) + + def get_rank_expert_workload(self) -> torch.Tensor: + self.moe_load = self.model.get_all_moe_loads() + return self.moe_load + + def get_init_expert_map(self, num_moe_layers): + expert_map = self.model.get_all_expert_map(num_moe_layers) + if dist.is_initialized(): + world_size = dist.get_world_size() + + gathered = torch.empty( + (world_size, *expert_map.shape), # [W, L, E] + dtype=expert_map.dtype, + device=expert_map.device) + + dist.all_gather_into_tensor(gathered, expert_map) + all_maps = gathered.permute(1, 0, 2) + all_expert_maps = all_maps.cpu() + + for layer_idx in range(num_moe_layers): + self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \ + all_expert_maps[layer_idx][self.rank_id] + + return all_expert_maps + + def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path): + + try: + expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor( + expert_map_path) + expert_map_all = self.local2global(expert_map_tensor) + except (TypeError, FileNotFoundError, OSError): + expert_map_all = self.determine_expert_map_all() + + for layer_idx in range(num_moe_layers): + if self.model.config.model_type == "qwen3_moe": + self.expert_map_per_layer_cpu[layer_idx] = \ + expert_map_all[layer_idx][self.rank_id] + else: + self.expert_map_per_layer_cpu[layer_idx + self.num_dense_layers] = \ + expert_map_all[layer_idx][self.rank_id] + return expert_map_all + + def _expert_file_to_tensor(self, expert_map_path: str): + with open(expert_map_path, "r") as f: + data = json.load(f) + layers_num = data["moe_layer_count"] + gpus_num = data["layer_list"][0]["device_count"] + + tensor_data = [] + for layer in data["layer_list"]: + device_data = [] + for device in layer["device_list"]: + device_data.append(device["device_expert"]) + tensor_data.append(device_data) + expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32) + return expert_map_tensor, layers_num, 
gpus_num + logger.error(f"failed to read expert_map_path: {expert_map_path}") + + def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str): + if self.rank_id == 0: + num_local_experts = expert_maps.max() + 1 + expert_maps_local = self.global2local(expert_maps, + num_local_experts) + + expert_maps_list = expert_maps_local.tolist() + record: dict[str, Any] = { + "moe_layer_count": len(expert_maps_list), + "layer_list": [] + } + + for layer_idx, layer_data in enumerate(expert_maps_list): + layer_record: dict[str, Any] = { + "layer_id": layer_idx, + "device_count": len(layer_data), + "device_list": [] + } + + for device_idx, experts in enumerate(layer_data): + device_record = { + "device_id": device_idx, + "device_expert": experts + } + layer_record["device_list"].append(device_record) + + record["layer_list"].append(layer_record) + + with open(expert_map_record_path, "w") as f: + json.dump(record, f, indent=4) + + def do_update_expert_map(self, layer_id, updated_expert_map): + self.expert_map_per_layer[layer_id] = updated_expert_map.clone() + self.expert_map_per_layer_cpu[layer_id] = updated_expert_map.clone() + + def do_update_expert_weight(self, layer_id, local_expert_to_replace, + buffer_tensor_id): + for expert_tensor, buffer_tensor in zip( + self.expert_param_per_layer[layer_id][local_expert_to_replace], + self.buffer_tensor_list[buffer_tensor_id]): + expert_tensor = buffer_tensor.clone() + logger.debug(f"Expert tensor shape is :{expert_tensor.shape}") + + def do_update_log2phy_map(self, layer_id, updated_log2phy_map): + if self.log2phy_map_per_layer[layer_id] is not None: + self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map) + + def global2local(self, placement: torch.Tensor, + E_local: int) -> torch.Tensor: + + L, G, _ = placement.shape + device = placement.device + + pt_local = torch.full((L, G, E_local), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement >= 0 + l_idx, g_idx, k_idx = 
valid.nonzero(as_tuple=True) + + slot_idx = placement[l_idx, g_idx, k_idx] + + pt_local[l_idx, g_idx, slot_idx] = k_idx + + return pt_local + + def local2global(self, placement_local: torch.Tensor) -> torch.Tensor: + + L, G, E_local = placement_local.shape + device = placement_local.device + + max_id = torch.max(placement_local) + E_global = (max_id + 1).item() if max_id >= 0 else 0 + + if E_global == 0: + return torch.empty((L, G, 0), dtype=torch.long, device=device) + + placement_global = torch.full((L, G, E_global), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement_local >= 0 + l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) + gid_idx = placement_local[l_idx, g_idx, slot_idx] + + placement_global[l_idx, g_idx, gid_idx] = slot_idx + + return placement_global + + def determine_expert_map_all(self): + if self.world_size == 1: + local_ids = torch.arange(self.global_expert_num, dtype=torch.int32) + return local_ids.view(1, 1, -1).expand(self.num_moe_layers, 1, -1) + + local_num_experts = self.global_expert_num // self.world_size + + expert_map_all = torch.full( + (self.num_moe_layers, self.world_size, self.global_expert_num), + -1, + dtype=torch.int32) + + for r in range(self.world_size): + if r < self.world_size - 1: + start = r * local_num_experts + end = (r + 1) * local_num_experts + local_count = local_num_experts + else: + start = r * local_num_experts + end = self.global_expert_num + local_count = self.global_expert_num - r * local_num_experts + + if r < self.init_redundancy_expert: + local_count += 1 + if end < self.global_expert_num: + end += 1 + else: + start -= 1 + + local_ids = torch.arange(local_count, dtype=torch.int32) + expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand( + self.num_moe_layers, -1) + + return expert_map_all diff --git a/vllm_ascend/eplb/core/__init__.py b/vllm_ascend/eplb/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py new file mode 100644 index 0000000..a170987 --- /dev/null +++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py @@ -0,0 +1,137 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +from enum import Enum + +import torch.distributed as dist +from vllm.logger import logger + + +class ExpertWeightUpdateState(Enum): + WAITING = 0 # waiting for updated expert_map by EplbWorker + READY = 1 # ready for d2d expert weights updating + TRANSFERRING = 2 # d2d finished and waiting for updating expert_map into model + + +class D2DExpertWeightLoader: + + def __init__(self): + self.comm_op_list = None + self.updated_expert_map = None + self.updated_log2phy_map = None + self.layer_id = -1 # layer id to be updated + self.state = ExpertWeightUpdateState.WAITING + self.recv_expert_list = [] + self.mock_flag = True + + def set_adator(self, eplb_adaptor): + self.eplb_adaptor = eplb_adaptor + + def generate_expert_d2d_transfer_task(self, expert_send_info, + expert_recv_info, updated_expert_map, + layer_id): + # When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task + if self.state != ExpertWeightUpdateState.WAITING: + logger.error( + "current d2d weight update tasks are on-going, cannot accept new weight 
update task" + ) + return + + # If neither send nor receive task is needed for this layer on this rank, return + if not (expert_send_info or expert_recv_info): + return + + self.updated_expert_map = updated_expert_map + + self.layer_id = layer_id + self.comm_op_list = [] + for send_info in expert_send_info: + dst_rank, global_expert_id_to_send = send_info + local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[ + layer_id][global_expert_id_to_send].item() + for src_tensor in self.eplb_adaptor.expert_param_per_layer[ + layer_id][local_expert_id]: + self.comm_op_list.append( + dist.P2POp(dist.isend, src_tensor, dst_rank)) + + buffer_tensor_id = 0 + for recv_info in expert_recv_info: + recv_rank, global_expert_id_to_recv = recv_info + for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[ + buffer_tensor_id]: + self.comm_op_list.append( + dist.P2POp(dist.irecv, buffer_tensor, recv_rank)) + local_expert_to_replace = self.updated_expert_map[ + global_expert_id_to_recv].item() + self.recv_expert_list.append( + (local_expert_to_replace, buffer_tensor_id)) + buffer_tensor_id += 1 + + self.state = ExpertWeightUpdateState.READY + + def set_log2phy_map(self, log2phy_map): + self.updated_log2phy_map = log2phy_map + + def asyn_expert_weight_transfer(self, reqs): + # Only when send/recv tasks are parsed into self.comm_op_list, d2d send/recv tasks can be luanched + if self.state != ExpertWeightUpdateState.READY: + return + + # set asynchronous stream for d2d expert weight transfer + if self.comm_op_list: + ret_list = dist.batch_isend_irecv(self.comm_op_list) + reqs.extend(ret_list) + + self.state = ExpertWeightUpdateState.TRANSFERRING + + def update_expert_map_and_weight(self, reqs): + # Only after send/recv tasks have been luanched, expert_map and weight can be updated + if self.state != ExpertWeightUpdateState.TRANSFERRING: + return + + # Waiting for send/recv tasks finish + for req in reqs: + req.wait() + + if self.comm_op_list is not None: + self.comm_op_list = 
None + + # update expert_map + self.eplb_adaptor.do_update_expert_map(self.layer_id, + self.updated_expert_map) + + # update log2phy_map + self.eplb_adaptor.do_update_log2phy_map(self.layer_id, + self.updated_log2phy_map) + + # update expert weight + buffer_tensor_id = 0 + for recv_expert_info in self.recv_expert_list: + local_expert_to_replace, buffer_tensor_id = recv_expert_info + self.eplb_adaptor.do_update_expert_weight(self.layer_id, + local_expert_to_replace, + buffer_tensor_id) + + logger.info( + f"[EPLB] finished update expert weight for layer: {self.layer_id}") + + self.recv_expert_list = [] + self.updated_expert_map = None + self.layer_id = -1 + self.state = ExpertWeightUpdateState.WAITING + + def load_impl(self, old_expert_table, new_expert_table): + raise NotImplementedError diff --git a/vllm_ascend/eplb/core/eplb_utils.py b/vllm_ascend/eplb/core/eplb_utils.py new file mode 100644 index 0000000..9a1c3bd --- /dev/null +++ b/vllm_ascend/eplb/core/eplb_utils.py @@ -0,0 +1,135 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove eplb utils. 
import random

import torch

try:
    from vllm.logger import logger
except ImportError:  # pragma: no cover - keep these utils importable standalone
    import logging

    logger = logging.getLogger(__name__)


def determine_default_expert_map(global_expert_num, world_size, rank_id,
                                 global_redundant_expert_num):
    """Build the default local expert map for ``rank_id``.

    Args:
        global_expert_num: total number of logical experts.
        world_size: number of EP ranks.
        rank_id: rank this map is built for.
        global_redundant_expert_num: how many ranks host one extra
            (redundant) expert; the first ``global_redundant_expert_num``
            ranks are extended by one expert each.

    Returns:
        ``(local_count, expert_map)`` — the number of experts hosted on this
        rank and an int32 tensor of length ``global_expert_num`` mapping
        global expert id -> local expert id (``-1`` for experts not hosted
        here).
    """
    if world_size == 1:
        # A single rank hosts every expert: identity mapping.
        local_ids = torch.arange(global_expert_num, dtype=torch.int32)
        return (global_expert_num, local_ids)

    local_num_experts = global_expert_num // world_size

    expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32)

    if rank_id < world_size - 1:
        start = rank_id * local_num_experts
        end = (rank_id + 1) * local_num_experts
        local_count = local_num_experts
    else:
        # The last rank absorbs the remainder of an uneven split.
        start = rank_id * local_num_experts
        end = global_expert_num
        local_count = global_expert_num - rank_id * local_num_experts

    # The first `global_redundant_expert_num` ranks each host one extra
    # expert: extend the slice forward, or backward when already at the end.
    if isinstance(global_redundant_expert_num,
                  int) and rank_id < global_redundant_expert_num:
        local_count += 1
        if end < global_expert_num:
            end += 1
        else:
            start -= 1

    if isinstance(local_count, int):
        local_ids = torch.arange(local_count, dtype=torch.int32)
        expert_map[start:end] = local_ids

    return (local_count, expert_map)


def generate_log2phy_map(expert_map):
    """Expand per-rank expert maps into a logical->physical expert map.

    Args:
        expert_map: int tensor of shape ``(num_ranks, num_global_expert)``
            where entry ``(r, e)`` is the local slot of expert ``e`` on rank
            ``r``, or ``-1`` when rank ``r`` does not host it.

    Returns:
        Tensor of the same shape whose entries are global physical expert
        ids (``rank * num_local_experts + local_slot``). Ranks that do not
        host an expert are pointed at a rank that does (chosen at random
        when there are several holders).
    """
    num_local_experts = expert_map.max() + 1
    log2phy_map = expert_map.clone()
    num_ranks, num_global_expert = log2phy_map.shape

    # Offset each rank's local slot ids into the global physical id space.
    row_indices = torch.arange(num_ranks).view(-1, 1).expand(
        num_ranks, num_global_expert) * num_local_experts
    log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1]

    for idx in range(num_global_expert):
        positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0]
        negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0]
        num_rank_holding_expert = positive_rank_idx.size(0)

        if num_rank_holding_expert == 0:
            # No rank hosts this expert; fall back to physical expert 0.
            # BUGFIX: `continue` here — previously control fell through to
            # the `else` branch below, where random.choice on an empty
            # sequence raised and spuriously logged an error.
            log2phy_map[:, idx] = torch.full((num_ranks, ),
                                             0,
                                             dtype=log2phy_map.dtype)
            continue

        if num_rank_holding_expert == 1:
            # Exactly one holder: all other ranks route to it.
            log2phy_map[negative_rank_idx, idx] = torch.full(
                (num_ranks - 1, ),
                log2phy_map[positive_rank_idx, idx].item(),
                dtype=log2phy_map.dtype)
        else:
            try:
                # Several holders: spread the non-holding ranks randomly.
                random_list = [
                    random.choice(log2phy_map[positive_rank_idx, idx])
                    for _ in range(num_ranks - num_rank_holding_expert)
                ]
                log2phy_map[negative_rank_idx,
                            idx] = torch.tensor(random_list,
                                                dtype=log2phy_map.dtype)
            except Exception as e:
                logger.error(f"Fail to get log2phy_map: {str(e)}")

    return log2phy_map


def determine_default_log2phy_map(global_expert_num, world_size, rank_id,
                                  global_redundant_expert_num):
    """Build the default logical->physical map row for ``rank_id``.

    Reconstructs the default expert maps of *all* ranks (mirroring
    ``determine_default_expert_map``) and derives the log2phy map from them.
    """
    if world_size == 1:
        local_ids = torch.arange(global_expert_num, dtype=torch.int32)
        expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1)
        log2phy_map_all = generate_log2phy_map(expert_map_all)
        return log2phy_map_all[rank_id]

    local_num_experts = global_expert_num // world_size

    expert_map_all = torch.full((world_size, global_expert_num),
                                -1,
                                dtype=torch.int32)

    for r in range(world_size):
        if r < world_size - 1:
            start = r * local_num_experts
            end = (r + 1) * local_num_experts
            local_count = local_num_experts
        else:
            start = r * local_num_experts
            end = global_expert_num
            local_count = global_expert_num - r * local_num_experts

        # BUGFIX: the redundant-expert extension must depend on the rank
        # being laid out (`r`), not on the caller's `rank_id`; previously
        # every row was extended (or none) based on the calling rank,
        # diverging from determine_default_expert_map.
        if isinstance(global_redundant_expert_num,
                      int) and r < global_redundant_expert_num:
            local_count += 1
            if end < global_expert_num:
                end += 1
            else:
                start -= 1

        if isinstance(local_count, int):
            local_ids = torch.arange(local_count, dtype=torch.int32)
            expert_map_all[r, start:end] = local_ids

    log2phy_map_all = generate_log2phy_map(expert_map_all)

    return log2phy_map_all[rank_id]
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (see full header at file top).
# This file is a part of the vllm-ascend project.
#
from multiprocessing import Process, Queue
from typing import Any

import networkx as nx  # type: ignore
import numpy as np
import torch
import torch.distributed as dist
from vllm.logger import logger

from vllm_ascend.eplb.core.eplb_utils import generate_log2phy_map
from vllm_ascend.eplb.core.policy.policy_factory import (DynamicConfig,
                                                         PolicyFactory)


class EplbWorker:
    """Computes rebalanced expert placements from MOE load statistics.

    State is exchanged with the main process via ``shared_dict``:
    ``expert_maps`` and ``moe_load`` are read, updated ``expert_maps`` are
    written back. One ``do_update`` call performs one rebalancing step.
    """

    def __init__(self, shared_dict, policy_type, enable_d2d: bool = True):
        self.policy_type = policy_type
        self.policy = PolicyFactory.generate_policy(policy_type,
                                                    DynamicConfig())
        self.shared_dict = shared_dict
        self.old_expert_maps = None
        self.enable_d2d = enable_d2d
        self.rank_id = dist.get_rank()

    def do_update(self):
        """Run one rebalancing step.

        Returns:
            Packed per-rank update info (see ``pack_update_info``), or
            ``None`` when no load information is available yet.

        Raises:
            ValueError: if the initial expert maps are missing from
                ``shared_dict``.
        """
        # The worker shares the host with inference; keep it single-threaded.
        torch.set_num_threads(1)
        if self.old_expert_maps is None:
            self.old_expert_maps = self.get_init_expert_maps()
            if self.old_expert_maps is not None:
                self.num_local_experts = self.old_expert_maps.max() + 1
            else:
                raise ValueError("Failed to get expert_maps from shared_dict.")

        # Get MOE load information.
        load_info = self.fetch_and_sum_load_info()
        if load_info is None:
            return

        # Get the updated expert table based on the workload information.
        old_placement = self.global2local(self.old_expert_maps,
                                          self.num_local_experts)
        _, _, new_placement = self.calculate_rebalance_experts(
            load_info, old_placement)

        if not torch.is_tensor(new_placement):
            new_placement = torch.tensor(new_placement)
        self.check_expert_placement(old_placement, new_placement)
        new_expert_maps = self.local2global(new_placement)
        self.update_expert_map(new_expert_maps)

        update_info = self.compose_expert_update_info_greedy(
            new_expert_maps, self.old_expert_maps)
        self.old_expert_maps = new_expert_maps
        logger.info("EPLB Process compute complete")

        packed_update_info = self.pack_update_info(update_info)

        return packed_update_info

    def check_expert_placement(self, old_placement, new_placement):
        """Validate ``new_placement`` layer by layer.

        Any layer that loses a logical expert, replicates an expert on one
        NPU, or moves an expert within an NPU is reverted (in place) to its
        old placement.
        """
        num_layers = old_placement.shape[0]
        num_ranks = old_placement.shape[1]

        for layer_id in range(num_layers):
            # Check if any logical expert is not placed on any rank.
            if torch.unique(new_placement[layer_id]).numel() < torch.unique(
                    old_placement[layer_id]).numel():
                logger.error(
                    f"There exists expert not placed on any rank in layer {layer_id}"
                )
                new_placement[layer_id] = old_placement[layer_id]
                continue

            for rank_id in range(num_ranks):
                new_placement_check = new_placement[layer_id][rank_id]
                old_placement_check = old_placement[layer_id][rank_id]

                # Check if the same logical expert is replicated on one NPU.
                if new_placement_check.numel() != torch.unique(
                        new_placement_check).numel():
                    logger.error(
                        f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
                    )
                    new_placement[layer_id] = old_placement[layer_id]
                    break

                # Check if there is any expert movement inside one NPU.
                expert_not_move = torch.isin(new_placement_check,
                                             old_placement_check)
                if not torch.equal(new_placement_check[expert_not_move],
                                   old_placement_check[expert_not_move]):
                    logger.error(
                        f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
                    )
                    new_placement[layer_id] = old_placement[layer_id]
                    break

    def compose_expert_update_info_bipartite(self, updated_expert_maps_org,
                                             current_expert_maps_org):
        """Yield per-layer ``(send_info, recv_info, new_map, layer_id)``.

        Transfers are scheduled round by round via maximum bipartite
        matching between source ranks and destination ranks, so each rank
        sends/receives at most one expert per round.
        """
        # Work on numpy copies; yield the original torch rows unchanged.
        updated_expert_maps = np.array(updated_expert_maps_org.clone())
        current_expert_maps = np.array(current_expert_maps_org.clone())

        num_layers = current_expert_maps.shape[0]

        for layer_id in range(num_layers):
            updated_expert_maps_this_layer = updated_expert_maps[layer_id]
            current_expert_maps_this_layer = current_expert_maps[layer_id]
            updated_expert_maps_this_layer_org = updated_expert_maps_org[
                layer_id]

            expert_send_info_this_layer: dict[Any, Any] = {}
            expert_recv_info_this_layer: dict[Any, Any] = {}

            # Guard Clause: nothing to transfer for this layer.
            # BUGFIX: `continue` after the yield — previously control fell
            # through and the trailing yield emitted the layer a second time.
            if (np.equal(updated_expert_maps_this_layer,
                         current_expert_maps_this_layer)).all():
                yield (expert_send_info_this_layer,
                       expert_recv_info_this_layer,
                       updated_expert_maps_this_layer_org, layer_id)
                continue

            # Parse expert_ids each rank needs to receive from other ranks.
            dst_rank_indices, experts_to_recv = np.where(
                (current_expert_maps_this_layer == -1)
                & (updated_expert_maps_this_layer != -1))

            # Record candidate source ranks for each expert to transfer.
            src_ranks_set = dict()
            for idx in range(len(dst_rank_indices)):
                expert_id = experts_to_recv[idx].item()
                if expert_id not in src_ranks_set:
                    src_ranks_set[expert_id] = np.where(
                        current_expert_maps_this_layer[:, expert_id] != -1)[0]

            # Loop until all experts are scheduled.
            while len(dst_rank_indices) > 0:
                # Construct the bipartite graph: int nodes are src ranks,
                # str nodes are dst ranks (kept distinct on purpose).
                graph_expert_update: nx.Graph = nx.Graph()
                for idx in range(len(dst_rank_indices)):
                    dst_rank_id = dst_rank_indices[idx].item()
                    expert_id = experts_to_recv[idx].item()
                    src_rank_ids = src_ranks_set[expert_id]
                    graph_expert_update.add_nodes_from(src_rank_ids,
                                                       bipartite=0)
                    graph_expert_update.add_node(str(dst_rank_id),
                                                 bipartite=1)
                    for src_rank_id in src_rank_ids:
                        graph_expert_update.add_edge(src_rank_id,
                                                     str(dst_rank_id))

                # The graph may not be connected; match each component.
                connected_components = list(
                    nx.connected_components(graph_expert_update))
                all_matches = {}
                for component in connected_components:
                    subgraph = graph_expert_update.subgraph(component)
                    component_matching = nx.bipartite.maximum_matching(
                        subgraph)
                    all_matches.update(component_matching)

                for src_rank, dst_rank in all_matches.items():
                    dst_rank = int(dst_rank)
                    assert src_rank != dst_rank
                    # maximum_matching returns both directions; only keep
                    # src(bipartite=0) -> dst edges.
                    if graph_expert_update.nodes[src_rank]['bipartite'] == 0:
                        # Currently unscheduled experts wanted by dst_rank.
                        experts_v = experts_to_recv[np.where(
                            dst_rank_indices == dst_rank)]
                        expert_id = np.intersect1d(
                            experts_v,
                            np.where(current_expert_maps_this_layer[src_rank]
                                     != -1))[0]

                        # Record send/recv pairs.
                        if src_rank not in expert_send_info_this_layer:
                            expert_send_info_this_layer[src_rank] = []
                        if dst_rank not in expert_recv_info_this_layer:
                            expert_recv_info_this_layer[dst_rank] = []
                        expert_send_info_this_layer[src_rank].append(
                            (dst_rank, expert_id))
                        expert_recv_info_this_layer[dst_rank].append(
                            (src_rank, expert_id))

                        remove_index = np.where(
                            np.logical_and(dst_rank_indices == dst_rank,
                                           experts_to_recv == expert_id))

                        dst_rank_indices = np.delete(dst_rank_indices,
                                                     remove_index)
                        experts_to_recv = np.delete(experts_to_recv,
                                                    remove_index)

            yield (expert_send_info_this_layer, expert_recv_info_this_layer,
                   updated_expert_maps_this_layer_org, layer_id)

    # TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases
    def compose_expert_update_info_greedy(self, updated_expert_maps,
                                          current_expert_maps):
        """Yield per-layer ``(send_info, recv_info, new_map, layer_id)``,
        greedily picking the first candidate source rank for each transfer.
        """
        num_layers = current_expert_maps.shape[0]
        for layer_id in range(num_layers):
            updated_expert_maps_this_layer = updated_expert_maps[layer_id]
            current_expert_maps_this_layer = current_expert_maps[layer_id]

            expert_send_info_this_layer: dict[Any, Any] = {}
            expert_recv_info_this_layer: dict[Any, Any] = {}

            # Guard Clause: nothing to transfer for this layer.
            # BUGFIX: `continue` after the yield — previously control fell
            # through and the trailing yield emitted the layer a second time.
            if torch.equal(updated_expert_maps_this_layer,
                           current_expert_maps_this_layer):
                yield (expert_send_info_this_layer,
                       expert_recv_info_this_layer,
                       updated_expert_maps_this_layer, layer_id)
                continue

            # Parse expert_ids each rank needs to receive from other ranks.
            dst_rank_indices, experts_to_recv = torch.where(
                (current_expert_maps_this_layer == -1)
                & (updated_expert_maps_this_layer != -1))

            # Parse expert_ids each rank needs to send to other ranks.
            src_rank_indices, experts_to_send = torch.where(
                (current_expert_maps_this_layer != -1)
                & (updated_expert_maps_this_layer == -1))

            for idx in range(len(dst_rank_indices)):
                dst_rank_id = dst_rank_indices[idx].item()
                expert_id = experts_to_recv[idx].item()
                if dst_rank_id not in expert_recv_info_this_layer:
                    expert_recv_info_this_layer[dst_rank_id] = []

                if not torch.isin(torch.tensor(expert_id),
                                  experts_to_send).any():
                    # Expert is not evicted anywhere: copy it from any NPU
                    # currently holding it.
                    candidate_src_rank_indices = torch.where(
                        current_expert_maps_this_layer[:, expert_id] != -1)[0]
                else:
                    candidate_src_rank_indices = src_rank_indices[
                        experts_to_send == expert_id]

                # TODO: improve selection criterion of npu sending expert_id considering such as intra-node or inter-node...
                src_rank_id = candidate_src_rank_indices[0].item()
                if src_rank_id not in expert_send_info_this_layer:
                    expert_send_info_this_layer[src_rank_id] = []

                expert_send_info_this_layer[src_rank_id].append(
                    (dst_rank_id, expert_id))
                expert_recv_info_this_layer[dst_rank_id].append(
                    (src_rank_id, expert_id))

            yield (expert_send_info_this_layer, expert_recv_info_this_layer,
                   updated_expert_maps_this_layer, layer_id)

    def calculate_rebalance_experts(self, load_info, old_placement):
        """Compute the new placement via the policy's ``rebalance_experts``.

        Returns ``(changed, priority, new_map)``; ``(False, None, None)``
        before the first expert map is known.
        """
        if self.old_expert_maps is None:
            return False, None, None

        changed, priority, new_map = self.policy.rebalance_experts(
            old_placement, load_info)
        return changed, priority, new_map

    def get_init_expert_maps(self):
        """Read the initial expert_map from shared_dict (or None)."""
        return self.shared_dict.get("expert_maps", None)

    def fetch_and_sum_load_info(self):
        """Read the latest moe_load from shared_dict (or None).

        Shape: [num_moe_layers, num_experts_per_layer].
        """
        return self.shared_dict.get("moe_load", None)

    def update_expert_map(self, expert_maps):
        """Publish the updated expert maps back to the main process."""
        self.shared_dict["expert_maps"] = expert_maps

    def global2local(self, placement: torch.Tensor,
                     E_local: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Invert (L, G, E_global) global->local maps into (L, G, E_local)
        local-slot -> global-expert-id placements (-1 for empty slots)."""
        L, G, _ = placement.shape
        device = placement.device

        pt_local = torch.full((L, G, E_local),
                              fill_value=-1,
                              dtype=torch.long,
                              device=device)

        valid = placement >= 0
        l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)

        slot_idx = placement[l_idx, g_idx, k_idx]

        pt_local[l_idx, g_idx, slot_idx] = k_idx

        return pt_local

    def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
        """Inverse of ``global2local``: (L, G, E_local) local placements back
        to (L, G, E_global) global->local-slot maps."""
        L, G, E_local = placement_local.shape
        device = placement_local.device

        max_id = torch.max(placement_local)
        E_global = (max_id + 1).item() if max_id >= 0 else 0

        if E_global == 0:
            return torch.empty((L, G, 0), dtype=torch.long, device=device)

        placement_global = torch.full((L, G, E_global),
                                      fill_value=-1,
                                      dtype=torch.long,
                                      device=device)

        valid = placement_local >= 0
        l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
        gid_idx = placement_local[l_idx, g_idx, slot_idx]

        placement_global[l_idx, g_idx, gid_idx] = slot_idx

        return placement_global

    def pack_update_info(self, update_info_generator):
        """Pack per-layer update info tuples for efficient IPC.

        Returns a list of (send, recv, expert_map, log2phy_map, layer_id)
        tuples restricted to this worker's rank.
        """
        send_all = []
        recv_all = []
        maps = []
        log2phy_all = []
        layer_ids = []

        for send_info, recv_info, new_expert_map, layer_id in update_info_generator:

            send_info_this_rank = send_info[
                self.rank_id] if self.rank_id in send_info else []
            recv_info_this_rank = recv_info[
                self.rank_id] if self.rank_id in recv_info else []
            send_all.append(send_info_this_rank)
            recv_all.append(recv_info_this_rank)

            maps.append(new_expert_map[self.rank_id].numpy().tolist())

            log2phy_map = generate_log2phy_map(new_expert_map)
            log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist())

            layer_ids.append(layer_id)

        return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids))


class EplbProcess:
    """Owns the EPLB subprocess and the queues used to talk to it."""

    def __init__(self,
                 shared_dict,
                 policy_type: int = 0,
                 enable_d2d: bool = True):
        """
        Args:
            shared_dict: Cross-process shared dict returned by Manager().dict()
            policy_type: Integer passed to PolicyFactory.generate_policy
            enable_d2d: Whether to enable D2D loading
        """
        self.shared_dict = shared_dict
        self.policy_type = policy_type
        self.enable_d2d = enable_d2d
        self.planner_q: Queue[Any] = Queue()
        # maxsize=1 makes put() block until the consumer drained the
        # previous update (see worker_process).
        self.block_update_q: Queue[Any] = Queue(maxsize=1)

        # Create EplbWorker instance.
        self.worker = EplbWorker(self.shared_dict, self.policy_type,
                                 self.enable_d2d)

    def worker_process(self, planner_q, block_update_q):
        """Subprocess entry: wait on planner_q, run one update, then hand
        the result to the main process via block_update_q."""
        while True:
            try:
                planner_q.get()

                packed_update_info = self.worker.do_update()

                # BUGFIX: the previous empty()/put busy-wait burned CPU and
                # raced with the consumer; a blocking put on the maxsize-1
                # queue provides the same back-pressure natively.
                block_update_q.put(packed_update_info)

            except Exception as e:
                logger.warning(f"[EPLB] subprocess exiting due to error: {e}",
                               exc_info=True)
                break

    def _launch_process(self):
        """Launch the worker subprocess and return the Process handle."""
        proc = Process(target=self.worker_process,
                       args=(self.planner_q, self.block_update_q),
                       daemon=True)

        proc.start()
        return proc
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
from collections import defaultdict
from typing import cast

import numpy as np

try:
    from .policy_abstract import DynamicConfig, EplbPolicy
except ImportError:  # pragma: no cover - standalone fallback mirroring policy_abstract
    from abc import abstractmethod

    class DynamicConfig:  # type: ignore[no-redef]
        placement_policy = None
        max_transferred_expert_per_layer = 100
        ep_worldsize = 64
        num_die_per_host = 8

    class EplbPolicy:  # type: ignore[no-redef]

        def __init__(self, config: DynamicConfig):
            self.config = config

        @abstractmethod
        def rebalance_experts(self, current_expert_table, expert_workload):
            pass


class DynamicTable:
    # workload_table:
    # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> workload (heat)
    # at the corresponding position; -1 for unavailable/uncollected experts.
    workload_table = None

    # placement_table:
    # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> physical expert
    # ID at the corresponding position; -1 for unavailable/uncollected.
    placement_table = None


class DynamicEplb(EplbPolicy):
    """Heat-based dynamic EPLB policy: replicates hot experts as redundant
    copies and packs experts onto cards to balance per-card workload."""

    def __init__(self, config: DynamicConfig):
        super().__init__(config)

    @staticmethod
    def add_redundant(current_expert_table, expert_workload,
                      num_original_expert):
        """Collapse per-slot workloads into per-logical-expert workloads,
        summing the heat of all (redundant) copies of each expert."""
        layer_num, npu_num, experts_per_npu = expert_workload.shape
        workload_new = np.zeros((layer_num, num_original_expert))
        for layer_idx in range(layer_num):
            workload_dict: dict[int, int] = defaultdict(int)
            placement_layer = current_expert_table[layer_idx].copy()
            workload_layer = expert_workload[layer_idx].copy()
            for npu_idx in range(npu_num):
                for expert_idx in range(experts_per_npu):
                    workload_dict[placement_layer[npu_idx][
                        expert_idx]] += workload_layer[npu_idx][expert_idx]
            for expert_idx in range(num_original_expert):
                workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
        return workload_new

    @staticmethod
    # Split hot (high-load) experts into redundant experts
    def original_compute_balanced_pack_redundancy(origin_weights, card_num,
                                                  num_redundancy_expert):
        """Assign experts (plus redundant copies of the hottest ones) to
        ``card_num`` boxes, balancing total weight per box.

        Returns (per-box summary dicts, per-box expert-id lists).
        """
        # Step 1: repeatedly replicate the currently hottest expert,
        # amortizing its weight over its copies.
        route_expert_num = len(origin_weights)
        route_expert_redundancy: list[list[int]] = [
            [] for _ in range(route_expert_num)
        ]
        for i in range(num_redundancy_expert):
            sorted_indices = np.argsort([t[1] for t in origin_weights],
                                        kind='stable')[::-1]
            weights = [origin_weights[idx] for idx in sorted_indices]
            tmp_raw_weight = weights[0][1] * (
                len(route_expert_redundancy[weights[0][0]]) + 1)
            route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
            avg_weight = tmp_raw_weight / (
                len(route_expert_redundancy[weights[0][0]]) + 1)
            weights[0] = (weights[0][0], avg_weight)
            origin_weights = weights

        # Step 2: Calculate the number of items per box.
        expert_num = route_expert_num + num_redundancy_expert
        items_per_box = expert_num // card_num
        remaining_items = expert_num % card_num

        # Step 3: Initialize card_num boxes.
        boxes: list[list[int]] = [[] for _ in range(card_num)]
        boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
        box_weights = [0] * card_num  # total weight of each box
        box_counts = [0] * card_num  # number of items in each box
        # Pre-place the redundant copies, one per box in order.
        index = 0
        for i in range(route_expert_num):
            redundancy_num = len(route_expert_redundancy[i])
            for _ in range(redundancy_num):
                cur_weight = 0
                for item, weight in origin_weights:
                    if item == i:
                        cur_weight = weight

                boxes[index].append(i)
                boxes_weights[index].append(cur_weight)
                box_weights[index] += cur_weight
                box_counts[index] += 1
                index += 1

        sorted_indices = np.argsort([t[1] for t in origin_weights],
                                    kind='stable')[::-1]
        origin_weights = [origin_weights[idx] for idx in sorted_indices]
        # Step 4: Distribute the remaining items, heaviest first, into the
        # lightest box that is not full and does not already hold the item.
        for item_id, weight in origin_weights:
            min_box_index = -1
            for i in range(card_num):
                if item_id in boxes[i]:
                    continue
                if box_counts[i] < items_per_box or (box_counts[i]
                                                     == items_per_box
                                                     and remaining_items > 0):
                    if min_box_index == -1 or box_weights[i] < box_weights[
                            min_box_index]:
                        min_box_index = i

            # NOTE(review): if every box is full or already holds item_id,
            # min_box_index stays -1 and this appends to the LAST box
            # (Python -1 indexing) — presumably unreachable with valid
            # inputs; confirm before tightening to a raise.
            boxes[min_box_index].append(item_id)
            boxes_weights[min_box_index].append(weight)
            box_weights[min_box_index] += weight
            box_counts[min_box_index] += 1

            if box_counts[min_box_index] == (items_per_box +
                                             1) and remaining_items > 0:
                remaining_items -= 1

        # Step 5: Summarize each box's contents and total weight.
        result = []
        for i in range(card_num):
            result.append({
                "box_index": i + 1,
                "items": boxes[i],
                "weight": boxes_weights[i],
                "total_weight": box_weights[i],
                "item_count": box_counts[i]
            })

        return result, boxes

    # Split hot (high-load) experts into redundant experts
    @staticmethod
    def compute_balanced_pack_redundancy(origin_weights, card_num,
                                         num_redundancy_expert):
        """Variant of the packing above that materializes redundant copies
        as separate weighted items before distribution."""
        route_expert_num = len(origin_weights)
        route_expert_redundancy: list[list[int]] = [
            [] for _ in range(route_expert_num)
        ]
        for i in range(num_redundancy_expert):
            sorted_indices = np.argsort([t[1] for t in origin_weights],
                                        kind='stable')[::-1]
            weights = [origin_weights[idx] for idx in sorted_indices]
            tmp_raw_weight = weights[0][1] * (
                len(route_expert_redundancy[weights[0][0]]) + 1)
            route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
            avg_weight = tmp_raw_weight / (
                len(route_expert_redundancy[weights[0][0]]) + 1)
            weights[0] = (weights[0][0], avg_weight)
            origin_weights = weights

        expert_num = route_expert_num + num_redundancy_expert
        if card_num == 0:
            raise RuntimeError("card_num can not be 0.")
        items_per_box = expert_num // card_num
        remaining_items = expert_num % card_num

        boxes: list[list[int]] = [[] for _ in range(card_num)]
        boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
        box_weights = [0] * card_num
        box_counts = [0] * card_num

        all_weights = np.zeros((expert_num, ), dtype='object')
        all_weights[:route_expert_num] = origin_weights

        index = route_expert_num
        for i in range(route_expert_num):
            redundancy_num = len(route_expert_redundancy[i])
            for _ in range(redundancy_num):
                for item, weight in origin_weights:
                    if item == i:
                        all_weights[index] = (item, weight)
                        index += 1

        sorted_indices = np.argsort([t[1] for t in all_weights],
                                    kind='stable')[::-1]
        all_weights = [all_weights[idx] for idx in sorted_indices]
        for item_id, weight in all_weights:
            min_box_index = -1
            for i in range(card_num):
                if box_counts[i] < items_per_box or (box_counts[i]
                                                     == items_per_box
                                                     and remaining_items > 0):
                    if min_box_index == -1 or box_weights[i] < box_weights[
                            min_box_index]:
                        if item_id not in boxes[i]:
                            min_box_index = i

            # NOTE(review): same latent -1-index append as in
            # original_compute_balanced_pack_redundancy above.
            boxes[min_box_index].append(item_id)
            boxes_weights[min_box_index].append(weight)
            box_weights[min_box_index] += weight
            box_counts[min_box_index] += 1

            if box_counts[min_box_index] == (items_per_box +
                                             1) and remaining_items > 0:
                remaining_items -= 1

        result = []
        for i in range(card_num):
            result.append({
                "box_index": i + 1,
                "items": boxes[i],
                "weight": boxes_weights[i],
                "total_weight": box_weights[i],
                "item_count": box_counts[i]
            })

        return result, boxes

    # Scheme without redundant experts
    @staticmethod
    def compute_balanced_pack(origin_weights, card_num):
        """Pack experts onto cards without creating redundant copies."""
        sorted_indices = np.argsort([t[1] for t in origin_weights])[::-1]
        # BUGFIX(consistency): index via a comprehension like the sibling
        # methods do — `origin_weights[sorted_indices]` only worked when
        # origin_weights was already a numpy array and raised on lists.
        weights = [origin_weights[idx] for idx in sorted_indices]
        expert_num = len(weights)
        if card_num == 0:
            raise RuntimeError("card_num can not be 0.")
        items_per_box = expert_num // card_num
        remaining_items = expert_num % card_num

        boxes: list[list[int]] = [[] for _ in range(card_num)]
        boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
        box_weights = [0] * card_num
        box_counts = [0] * card_num

        for item_id, weight in weights:
            min_box_index = -1
            for i in range(card_num):
                if box_counts[i] < items_per_box or (box_counts[i]
                                                     == items_per_box
                                                     and remaining_items > 0):
                    if min_box_index == -1 or box_weights[i] < box_weights[
                            min_box_index]:
                        min_box_index = i

            boxes[min_box_index].append(item_id)
            boxes_weights[min_box_index].append(weight)
            box_weights[min_box_index] += weight
            box_counts[min_box_index] += 1

            if box_counts[min_box_index] == (items_per_box +
                                             1) and remaining_items > 0:
                remaining_items -= 1

        result = []
        for i in range(card_num):
            result.append({
                "box_index": i + 1,
                "items": boxes[i],
                "weight": boxes_weights[i],
                "total_weight": box_weights[i],
                "item_count": box_counts[i]
            })

        return result, boxes

    @staticmethod
    def get_redundant_num(npu_num, counts):
        """Number of redundant expert copies implied by per-expert counts.

        Cast to a plain int for consistency with DynamicEplbV2 (avoids a
        numpy scalar leaking into downstream arithmetic).
        """
        redundant_num_each_npu: int = int(np.sum(counts - 1))
        return redundant_num_each_npu

    @staticmethod
    def calculate_max_heat_per_layer(workload_table, layer_num):
        """Per layer, the heat of the hottest NPU (sum over its experts)."""
        max_heat_per_layer: list[float] = []
        for layer_idx in range(layer_num):
            npu_heats_now = np.sum(workload_table[layer_idx], axis=1)
            max_heat_per_layer.append(np.max(npu_heats_now))
        return max_heat_per_layer

    @staticmethod
    def constraint_expert_local_exchange(current_expert_table,
                                         global_deployment):
        """Reorder each card's new expert list so experts already present
        keep their old slot, avoiding intra-NPU expert movement."""
        for layer_id in range(len(global_deployment)):
            for card_id in range(len(global_deployment[layer_id])):
                current_list = [
                    int(x) for x in current_expert_table[layer_id][card_id]
                ]
                new_list = [
                    int(x) for x in global_deployment[layer_id][card_id]
                ]
                num = len(new_list)

                new_index = [-1] * num
                new_result = [-1] * num
                remaining_elements = []

                for i in range(num):
                    flag = True
                    for j in range(num):
                        if new_list[i] == current_list[j] and new_index[
                                j] == -1:
                            new_index[j] = 0
                            new_result[j] = current_list[j]
                            flag = False
                            break
                    if flag:
                        remaining_elements.append(new_list[i])

                index = 0
                for k in range(num):
                    if new_result[k] == -1:
                        new_result[k] = remaining_elements[index]
                        index += 1

                global_deployment[layer_id][card_id] = new_result

        return global_deployment

    def rebalance_experts(self, current_expert_table, expert_workload):
        """Compute a new balanced deployment for all layers.

        Returns:
            (change, per_layer_priority, new_deployment) where ``change`` is
            1 only if the projected peak heat drops below 95% of the
            current peak.
        """
        info = DynamicTable()
        info.workload_table = np.array(expert_workload)
        info.placement_table = np.array(current_expert_table)
        assert info.workload_table is not None
        layer_num, num_npus, experts_per_npu = info.workload_table.shape
        assert info.placement_table is not None
        row = cast(np.ndarray, info.placement_table[0])
        expert_ids, counts = np.unique(row, return_counts=True)
        num_redundancy_expert = self.get_redundant_num(num_npus, counts)
        num_original_expert = len(expert_ids)
        layer_workloads = self.add_redundant(info.placement_table,
                                             info.workload_table,
                                             num_original_expert)
        max_heat_per_layer_before = self.calculate_max_heat_per_layer(
            info.workload_table, layer_num)
        npu_heat_all_origin = sum(max_heat_per_layer_before)

        # Perform load balancing and deploy redundant experts.
        layer_num = layer_workloads.shape[0]
        expert_num = layer_workloads.shape[1]
        # Validate expert / card / redundancy counts.
        if num_original_expert != expert_num:
            raise ValueError(
                f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}"
            )

        if num_npus <= 0:
            raise ValueError("the number of NPUs must be greater than 0")

        if num_npus < num_redundancy_expert:
            raise ValueError(
                f"the number of NPUs {num_npus} must be greater than or equal to the number of redundant experts {num_redundancy_expert}"
            )

        # Number of experts deployed on each card includes one redundant expert.
        global_deployment: list[list[list[int]]] = [[[]
                                                     for _ in range(num_npus)]
                                                    for _ in range(layer_num)]
        # Per layer: normalize workloads and compute a balanced placement.
        max_heat_per_layer_after = np.zeros([layer_num])
        for layer in range(layer_num):
            weights = np.zeros((expert_num, ), dtype='object')
            for expert_id, workload_weight in enumerate(
                    layer_workloads[layer]):
                weights[expert_id] = (expert_id, workload_weight)

            result, layer_deployment = self.original_compute_balanced_pack_redundancy(
                weights, num_npus, num_redundancy_expert)

            global_deployment[layer] = layer_deployment
            max_heat_per_layer_after[layer] = max(
                result, key=lambda x: x['total_weight'])['total_weight']

        new_global_deployment = self.constraint_expert_local_exchange(
            current_expert_table, global_deployment)
        # Layers whose peak heat shrinks the most get the highest priority.
        layer_changed_ratio = []
        for layer_idx in range(layer_num):
            layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] /
                                       max_heat_per_layer_before[layer_idx])

        per_layer_priority = np.argsort(layer_changed_ratio)
        npu_heat_all_after = sum(max_heat_per_layer_after)

        change = 0
        if npu_heat_all_after < 0.95 * npu_heat_all_origin:
            change = 1

        return change, per_layer_priority, np.array(
            new_global_deployment).tolist()
+ [252, 253, 254, 255, 0] + """ + pass + + +class DynamicTable: + # workload_table: + # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position + # Size: number of layers * number of GPUs * number of experts per GPU per layer + # The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer + # For experts that are not available or collected, the value is set to -1 + workload_table = None + + # placement_table: + # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position + # Size: number of layers * number of GPUs * number of experts per GPU per layer + # The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer + # For experts that are not available or collected, the value is set to -1 + placement_table = None + + +class DynamicEplbV2(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + @staticmethod + def safe_divide(a, b): + if b == 0: + print("Division by zero is not allowed") + return 0 + return a / b + + @staticmethod + def safe_exact_divide(a, b): + if b == 0: + print("Division by zero is not allowed") + return 0 + return a // b + + @staticmethod + def safe_mod(a, b): + if b == 0: + print("Division by zero is not allowed") + return 0 + return a % b + + @staticmethod + def add_redundant(current_expert_table, expert_workload, + num_original_expert): + layer_num, npu_num, experts_per_npu = expert_workload.shape + workload_new = np.zeros((layer_num, num_original_expert)) + for layer_idx in range(layer_num): + workload_dict: dict[int, int] = defaultdict(int) + placement_layer = current_expert_table[layer_idx].copy() + workload_layer = expert_workload[layer_idx].copy() + for npu_idx in range(npu_num): + for expert_idx in range(experts_per_npu): + workload_dict[placement_layer[npu_idx][ + expert_idx]] += 
workload_layer[npu_idx][expert_idx] + for expert_idx in range(num_original_expert): + workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] + return workload_new + + @staticmethod + def get_redundant_num(npu_num, counts): + redundant_num_each_npu: int = int(np.sum(counts - 1)) + return redundant_num_each_npu + + @staticmethod + def calculate_max_heat_per_layer(workload_table, layer_num): + max_heat_per_layer: list[float] = [] + for layer_idx in range(layer_num): + npu_heats_now = np.sum(workload_table[layer_idx], axis=1) + max_heat_per_layer.append(np.max(npu_heats_now)) + return max_heat_per_layer + + def calculate_initial_imbalance(self, global_deployment, + new_layer_workloads): + + device_num = global_deployment.shape[1] + layer_imbalance = [] + expert_num = np.zeros_like(new_layer_workloads) + for layer_id, layer in enumerate(global_deployment): + for device in layer: + for expert_id in device: + expert_num[layer_id][expert_id] += 1 + + for layer_id, layer in enumerate(global_deployment): + cur_layer_max_workload = 0 + total_workload = 0 + for box in layer: + box_workload = 0 + for expert_id in box: + update_workload = self.safe_divide( + new_layer_workloads[layer_id][expert_id], + expert_num[layer_id][expert_id]) + box_workload += update_workload + total_workload += update_workload + if cur_layer_max_workload < box_workload: + cur_layer_max_workload = box_workload + + cur_layer_imbalance = self.safe_divide( + cur_layer_max_workload, + (self.safe_divide(total_workload, device_num))) + layer_imbalance.append(cur_layer_imbalance) + + return layer_imbalance + + def compute_redundant_assignments(self, base_experts, + num_redundant_experts, num_experts): + + redundant_assignments: list[list[int]] = [[] + for _ in range(num_experts)] + current_weights = base_experts.copy() + + for i in range(num_redundant_experts): + sorted_indices = np.argsort([w for _, w in current_weights], + kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in 
sorted_indices] + + target_expert = sorted_weights[0] + expert_id, original_weight = target_expert + + current_redundancy = len(redundant_assignments[expert_id]) + new_avg_weight = self.safe_divide( + original_weight * (current_redundancy + 1), + (current_redundancy + 2)) + + redundant_assignments[expert_id].append(num_experts + i) + current_weights[sorted_indices[0]] = (expert_id, new_avg_weight) + + sorted_indices = np.argsort([w for _, w in current_weights], + kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + return redundant_assignments, sorted_weights + + def repeat_compute_redundant_assignments(self, layer_workloads, rendun_pos, + num_experts, num_exist_expert, + device_assignments, device_counts, + expert_from_device, + com_between_devices): + + current_weights = np.zeros((num_experts, ), dtype='object') + for expert_id, workload_weight in enumerate(layer_workloads): + current_weights[expert_id] = (expert_id, workload_weight) + + devices_with_slots = [] + for device_id, device_rendun_pos in enumerate(rendun_pos): + if len(device_rendun_pos) != 0: + devices_with_slots.append(device_id) + + while devices_with_slots: + sorted_indices = np.argsort([w for _, w in current_weights], + kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + for index, target_weight in enumerate(sorted_weights): + expert_id, original_weight = target_weight + if original_weight == -1: + print("Error:Redundant expert failure re-occurred") + redundancy_successful = True + break + redundancy_successful = False + for cur_device_id in devices_with_slots: + if expert_id not in device_assignments[cur_device_id]: + pos = rendun_pos[cur_device_id].pop() + if len(rendun_pos[cur_device_id]) == 0: + devices_with_slots = [ + device_id for device_id in devices_with_slots + if device_id != cur_device_id + ] + device_assignments[cur_device_id][pos] = expert_id + device_counts[cur_device_id] += 1 + communication_box_index = 
expert_from_device[expert_id] + com_between_devices[cur_device_id][ + communication_box_index] = expert_id + new_weight = self.safe_divide( + (original_weight * num_exist_expert[expert_id]), + (num_exist_expert[expert_id] + 1)) + sorted_weights[index] = (expert_id, new_weight) + num_exist_expert[expert_id] += 1 + redundancy_successful = True + break + if redundancy_successful: + break + + sorted_indices = np.argsort([id for id, _ in sorted_weights], + kind='stable') + sorted_weights = [sorted_weights[i][1] for i in sorted_indices] + + return sorted_weights, device_assignments, device_counts, com_between_devices + + @staticmethod + def prepare_expert_list(base_experts, redundant_assignments, + num_redundant_experts): + redundant_expert_list = np.empty(num_redundant_experts, dtype=object) + + index = 0 + num_experts = len(redundant_assignments) + for expert_id in range(num_experts): + for _ in redundant_assignments[expert_id]: + redundant_expert_list[index] = (expert_id, + next(w + for eid, w in base_experts + if eid == expert_id)) + index += 1 + + sorted_indices = np.argsort([w for _, w in redundant_expert_list], + kind='stable')[::-1] + return [redundant_expert_list[i] for i in sorted_indices] + + @staticmethod + def non_redundant_expert_information(origin_deployment, updated_weights, + rendun_pos): + + device_num = len(origin_deployment) + num_experts_per_device = origin_deployment.shape[1] + device_assignments = [[-1 for _ in range(num_experts_per_device)] + for _ in range(device_num)] + device_weights = [[0 for _ in range(num_experts_per_device)] + for _ in range(device_num)] + device_loads = [0] * device_num + device_counts = [0] * device_num + + for device_id, device in enumerate(origin_deployment): + for index, expert_id in enumerate(device): + if index in rendun_pos[device_id]: + continue + device_assignments[device_id][index] = expert_id + cur_weight = next( + weight for expert_id_of_weight, weight in updated_weights + if expert_id_of_weight == expert_id) + 
device_weights[device_id][index] = cur_weight + device_loads[device_id] += cur_weight + device_counts[device_id] += 1 + + return device_assignments, device_weights, device_loads, device_counts + + def recomputing_initial_weight(self, layer_workloads, device_assignments): + num_all_experts = [0] * len(layer_workloads) + for device in device_assignments: + for expert_id in device: + if expert_id != -1: + num_all_experts[expert_id] += 1 + + cur_layer_workload = [] + for expert_id, weight in enumerate(layer_workloads): + if num_all_experts[expert_id] == 0: + cur_layer_workload.append(-1) + else: + cur_layer_workload.append( + self.safe_divide(weight, num_all_experts[expert_id])) + + return cur_layer_workload, num_all_experts + + def distribute_redun_experts(self, layer_workloads, device_assignments, + device_weights, device_loads, device_counts, + redundant_expert_list, expert_from_device, + num_experts, rendun_pos): + + num_devices = len(device_assignments) + com_between_devices: list[dict[int, + int]] = [{} for _ in range(num_devices)] + + for expert_id, weight in redundant_expert_list: + candidate = -1 + for dev_id in range(num_devices): + if len(rendun_pos[dev_id]) == 0: + continue + if expert_id in device_assignments[dev_id]: + continue + if candidate == -1 or device_loads[dev_id] < device_loads[ + candidate]: + candidate = dev_id + if candidate != -1: + pos = rendun_pos[candidate].pop() + device_assignments[candidate][pos] = expert_id + device_weights[candidate][pos] = weight + device_loads[candidate] += weight + device_counts[candidate] += 1 + + communication_box_index = expert_from_device[expert_id] + com_between_devices[candidate][ + communication_box_index] = expert_id + + if any(sublist for sublist in rendun_pos): + cur_layer_workload, num_exist_expert = self.recomputing_initial_weight( + layer_workloads, device_assignments) + + update_workload, device_assignments, device_counts, com_between_devices = self.repeat_compute_redundant_assignments( + 
cur_layer_workload, rendun_pos, num_experts, num_exist_expert, + device_assignments, device_loads, expert_from_device, + com_between_devices) + + device_loads = [0] * len(device_counts) + for device_id, device in enumerate(device_assignments): + for index, expert_id in enumerate(device): + device_weights[device_id][index] = update_workload[ + expert_id] + device_loads[device_id] += update_workload[expert_id] + + return device_assignments, device_weights, device_loads, device_counts, com_between_devices + + def redundancy_again(self, layer_workloads, origin_weights, + origin_deployment, expert_from_device, num_node, + is_node_redundant, rendun_pos): + + num_experts = len(origin_weights) + if is_node_redundant: + num_experts = num_experts * num_node + + num_redundant_experts = 0 + for rank_empty_pos in rendun_pos: + num_redundant_experts += len(rank_empty_pos) + + redundant_assignments, updated_weights = self.compute_redundant_assignments( + origin_weights, num_redundant_experts, num_experts) + + redundant_expert_list = self.prepare_expert_list( + updated_weights, redundant_assignments, num_redundant_experts) + + device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information( + origin_deployment, updated_weights, rendun_pos) + + device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts( + layer_workloads, device_assignments, device_weights, device_loads, + device_counts, redundant_expert_list, expert_from_device, + num_experts, rendun_pos) + + return device_assignments, device_weights, device_loads, device_counts, com_between_devices + + @staticmethod + def generate_allocation_report(device_assignments, device_weights, + device_loads, device_counts): + + report = [] + max_load = 0.0 + + for dev_id in range(len(device_assignments)): + current_load = device_loads[dev_id] + max_load = max(max_load, current_load) + + report.append({ + "device_id": dev_id + 1, + 
"assigned_experts": device_assignments[dev_id], + "expert_weights": device_weights[dev_id], + "total_load": current_load, + "expert_count": device_counts[dev_id] + }) + + return report, max_load + + @staticmethod + def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id, + next_device_id, cur_layer_result, com_between_devices): + + cur_device_deployment = cur_layer_result[cur_device_id][ + 'assigned_experts'] + next_device_deployment = cur_layer_result[next_device_id][ + 'assigned_experts'] + + cur_device_weight = cur_layer_result[cur_device_id]['expert_weights'] + next_device_weight = cur_layer_result[next_device_id]['expert_weights'] + + cur_expert_id = cur_device_deployment[cur_exchange_index] + next_expert_id = next_device_deployment[next_exchange_index] + cur_device_deployment[cur_exchange_index] = next_expert_id + next_device_deployment[next_exchange_index] = cur_expert_id + + cur_expert_weight = cur_device_weight[cur_exchange_index] + next_expert_weight = next_device_weight[next_exchange_index] + cur_device_weight[cur_exchange_index] = next_expert_weight + next_device_weight[next_exchange_index] = cur_expert_weight + + cur_layer_result[cur_device_id][ + 'total_load'] += next_expert_weight - cur_expert_weight + cur_layer_result[next_device_id][ + 'total_load'] += cur_expert_weight - next_expert_weight + + com_between_devices[cur_device_id][next_device_id] = next_expert_id + com_between_devices[next_device_id][cur_device_id] = cur_expert_id + + def redundant_expert_deployment(self, layer_workloads, original_deployment, + expert_from_device, node_num, + is_node_redundant, rendun_pos): + device_num, per_device_expert_num = original_deployment.shape + route_expert_num = layer_workloads.shape[0] + per_node_device_num = self.safe_exact_divide(device_num, node_num) + per_node_route_expert_num = per_node_device_num * ( + per_device_expert_num - 1) + + weights = np.zeros((route_expert_num, ), dtype='object') + for expert_id, workload_weight in 
enumerate(layer_workloads): + weights[expert_id] = (expert_id, workload_weight) + + if is_node_redundant: + + device_assignments = [] + device_weights = [] + device_loads = [] + device_counts = [] + com_between_devices = [] + + for node_id in range(node_num): + cur_node_weights = weights[node_id * + per_node_route_expert_num:(node_id + + 1) * + per_node_route_expert_num] + cur_original_deployment = original_deployment[ + node_id * per_node_device_num:(node_id + 1) * + per_node_device_num] + + cur_node_rendun_pos = rendun_pos[node_id * + per_node_device_num:(node_id + + 1) * + per_node_device_num] + + cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again( + layer_workloads, cur_node_weights, cur_original_deployment, + expert_from_device, node_num, is_node_redundant, + cur_node_rendun_pos) + device_assignments += cur_device_assignments + device_weights += cur_device_weights + device_loads += cur_device_loads + device_counts += cur_device_counts + com_between_devices += cur_com_between_devices + + else: + device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again( + layer_workloads, weights, original_deployment, + expert_from_device, node_num, is_node_redundant, rendun_pos) + report, max_load = self.generate_allocation_report( + device_assignments, device_weights, device_loads, device_counts) + + return report, max_load, com_between_devices + + @staticmethod + def two_device_exchange_experts(cur_device_result, exchange_device_result, + cur_exchanged_expert_id, + next_exchanged_expert_id, ave_workload, + increment, num_redundancy_expert): + + cur_device_weight = cur_device_result['expert_weights'] + next_device_weight = exchange_device_result['expert_weights'] + + cur_device_expert_id = cur_device_result['assigned_experts'] + next_device_expert_id = exchange_device_result['assigned_experts'] + + cur_device_total_weight = 
cur_device_result['total_load'] + next_device_total_weight = exchange_device_result['total_load'] + max_weight = max(cur_device_total_weight, next_device_total_weight) + + cur_exchange_index = -1 + next_exchange_index = -1 + + for index, weight in enumerate(cur_device_weight): + for next_index, next_weight in enumerate(next_device_weight): + change_flag = True + if (cur_device_expert_id[index] in next_device_expert_id + or next_device_expert_id[next_index] + in cur_device_expert_id): + change_flag = False + if (cur_device_expert_id[index] not in cur_exchanged_expert_id + ) and (next_device_expert_id[next_index] + not in next_exchanged_expert_id) and change_flag: + + cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight + next_total_weight_after_exchange = next_device_total_weight - next_weight + weight + exchange_max_weight = max( + cur_total_weight_after_exchange, + next_total_weight_after_exchange) + if exchange_max_weight < max_weight and ( + max_weight - + exchange_max_weight) >= (ave_workload * increment): + max_weight = exchange_max_weight + cur_exchange_index = index + next_exchange_index = next_index + + return cur_exchange_index, next_exchange_index + + def expert_exchange_between_devices(self, + ave_workload, + increment, + cur_layer_result, + com_between_devices, + num_redundancy_expert, + node_idx=0, + per_node_device_num=0, + is_node_redundant=False): + + if is_node_redundant: + cur_devices_result = cur_layer_result[node_idx * + per_node_device_num: + (node_idx + 1) * + per_node_device_num] + else: + cur_devices_result = cur_layer_result + + devices_total_weight = [] + for device in cur_devices_result: + devices_total_weight.append( + (device['total_load'], device['device_id'] - 1)) + + exchange_frequency = 100 + while exchange_frequency > 0: + exchange_frequency -= 1 + devices_total_weight.sort(key=lambda x: x[0]) + max_weight_device_id = devices_total_weight[-1][1] + exchange = False + for index in range(0, 
len(devices_total_weight) - 1): + min_weight_device_id = devices_total_weight[index][1] + if min_weight_device_id not in com_between_devices[ + max_weight_device_id]: + cur_exchanged_expert_id = list( + com_between_devices[max_weight_device_id].values()) + next_exchanged_expert_id = list( + com_between_devices[min_weight_device_id].values()) + + cur_exchange_index, next_exchange_index = self.two_device_exchange_experts( + cur_layer_result[max_weight_device_id], + cur_layer_result[min_weight_device_id], + cur_exchanged_expert_id, next_exchanged_expert_id, + ave_workload, increment, num_redundancy_expert) + + if cur_exchange_index != -1: + self.exchange_expert(cur_exchange_index, + next_exchange_index, + max_weight_device_id, + min_weight_device_id, + cur_layer_result, + com_between_devices) + + devices_total_weight[-1] = ( + cur_layer_result[max_weight_device_id] + ['total_load'], max_weight_device_id) + devices_total_weight[index] = ( + cur_layer_result[min_weight_device_id] + ['total_load'], min_weight_device_id) + exchange = True + break + + if not exchange: + break + + def exchange_experts(self, layer_result, layer_com_between_devices, + num_nodes, device_num, is_node_redundant, + ave_workload, increment, num_redundancy_expert, + org_deployment): + + global_deployment = [] + + if is_node_redundant: + per_node_device_num = self.safe_exact_divide(device_num, num_nodes) + for node_idx in range(num_nodes): + self.expert_exchange_between_devices( + ave_workload, increment, layer_result, + layer_com_between_devices, num_redundancy_expert, node_idx, + per_node_device_num, is_node_redundant) + else: + self.expert_exchange_between_devices(ave_workload, increment, + layer_result, + layer_com_between_devices, + num_redundancy_expert) + + max_workload = 0 + for box in layer_result: + global_deployment.append(box['assigned_experts']) + if max_workload < box['total_load']: + max_workload = box['total_load'] + + global_deployment = np.array(global_deployment) + + return 
global_deployment, max_workload + + def count_elements(self, lst): + count = 0 + for item in lst: + if isinstance(item, list): + count += self.count_elements(item) + else: + count += 1 + return count + + @staticmethod + def constraint_expert_local_exchange(current_expert_table, + global_deployment): + for layer_id in range(len(global_deployment)): + for card_id in range(len(global_deployment[layer_id])): + current_list = [ + int(x) for x in current_expert_table[layer_id][card_id] + ] + new_list = [ + int(x) for x in global_deployment[layer_id][card_id] + ] + num = len(new_list) + + new_index = [-1] * num + new_result = [-1] * num + remaining_elements = [] + + for i in range(num): + flag = True + for j in range(num): + if new_list[i] == current_list[j] and new_index[ + j] == -1: + new_index[j] = 0 + new_result[j] = current_list[j] + flag = False + break + if flag: + remaining_elements.append(new_list[i]) + + index = 0 + for k in range(num): + if new_result[k] == -1: + new_result[k] = remaining_elements[index] + index += 1 + + global_deployment[layer_id][card_id] = new_result + + return global_deployment + + def rebalance_experts(self, + current_expert_table, + expert_workload, + is_node_redundant=False, + increment=0.01): + info = DynamicTable() + info.workload_table = expert_workload.numpy() + info.placement_table = current_expert_table.numpy() + assert info.workload_table is not None + layer_num, num_npus, experts_per_npu = info.workload_table.shape + expert_ids, counts = np.unique(info.placement_table[0], + return_counts=True) + num_redundancy_expert = self.get_redundant_num(num_npus, counts) + num_original_expert = len(expert_ids) + layer_workloads = self.add_redundant(info.placement_table, + info.workload_table, + num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer( + info.workload_table, layer_num) + npu_heat_all_origin = sum(max_heat_per_layer_before) + + num_node = self.safe_exact_divide(num_npus, 8) + layer_num = 
layer_workloads.shape[0] + expert_num = layer_workloads.shape[1] + expert_from_device = np.zeros((layer_num, num_original_expert)) + + if num_original_expert != expert_num: + raise ValueError( + f"The number of original experts ({num_original_expert}) must match expert_num ({expert_num})" + ) + + if num_npus <= 0: + raise ValueError("The number of NPUs must be greater than 0") + + if num_npus < num_redundancy_expert: + raise ValueError( + f"The number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})" + ) + + global_deployment: list[list[list[int]]] = [[[] + for _ in range(num_npus)] + for _ in range(layer_num)] + layer_initial_imbalance = self.calculate_initial_imbalance( + info.placement_table, layer_workloads) + max_heat_per_layer_after = np.zeros([layer_num]) + sum_num = 0 + for layer in range(layer_num): + # print(f"Load imbalance ratio of layer {layer} under the new workload", layer_initial_imbalance[layer]) + if layer_initial_imbalance[layer] < 1.01: + global_deployment[layer] = info.placement_table[layer] + continue + + ave_workload = self.safe_divide(np.sum(layer_workloads[layer]), + num_npus) + + rendun_pos: list[list[int]] = [[] for _ in range(num_npus)] + existing_experts = set() + for device_id, device in enumerate(info.placement_table[layer]): + for index, expert_id in enumerate(device): + if expert_id not in existing_experts: + existing_experts.add(expert_id) + expert_from_device[layer][expert_id] = device_id + else: + rendun_pos[device_id].append(index) + + result, max_workload, com_between_devices = self.redundant_expert_deployment( + layer_workloads[layer], info.placement_table[layer], + expert_from_device[layer], num_node, is_node_redundant, + rendun_pos) + # print(layer, f"Imbalance Ratio after Redundancy Adjustment:", self.safe_divide(max_workload, ave_workload)) + + global_deployment[layer], new_max_workload = self.exchange_experts( + result, com_between_devices, num_node, num_npus, 
+ is_node_redundant, ave_workload, increment, + num_redundancy_expert, info.placement_table[layer]) + # print(layer, f"Imbalance Ratio after Swap Adjustment:", self.safe_divide(new_max_workload, ave_workload)) + + for device_id in range(num_npus): + com_between_devices[device_id] = { + key: value + for key, value in com_between_devices[device_id].items() + } + sum_num += self.count_elements(com_between_devices[device_id]) + + max_heat_per_layer_after[layer] = max( + result, key=lambda x: x['total_load'])['total_load'] + + layer_changed_ratio = [] + for layer_idx in range(layer_num): + layer_changed_ratio.append( + self.safe_divide(max_heat_per_layer_after[layer_idx], + max_heat_per_layer_before[layer_idx])) + + per_layer_priority = np.argsort(layer_changed_ratio) + npu_heat_all_after = sum(max_heat_per_layer_after) + + change = 0 + if npu_heat_all_after < 0.95 * npu_heat_all_origin: + change = 1 + + new_global_deployment = self.constraint_expert_local_exchange( + current_expert_table, global_deployment) + + return change, per_layer_priority, np.array( + new_global_deployment).tolist() diff --git a/vllm_ascend/eplb/core/policy/policy_factory.py b/vllm_ascend/eplb/core/policy/policy_factory.py new file mode 100644 index 0000000..bbf7315 --- /dev/null +++ b/vllm_ascend/eplb/core/policy/policy_factory.py @@ -0,0 +1,33 @@ +# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this factory. 
+from .policy_abstract import DynamicConfig, EplbPolicy +from .policy_dynamic_ep import DynamicEplb +from .policy_dynamic_ep_v2 import DynamicEplbV2 +from .policy_flashlb import FlashLB +from .policy_random import RandomLoadBalance + + +class PolicyFactory: + + @staticmethod + def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy: + policy = { + # Constraint applying Dynamic EPLB policy V2: + # If there exists redundant expert: + # only one redundant expert can be placed in one NPU and its physical expert index must be 0 + + # Applying greedy d2d expert weight update composing + 0: + RandomLoadBalance, # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3 + 1: + DynamicEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load + 2: + DynamicEplbV2, # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle + 3: + FlashLB, # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment + } + policy_class = policy.get(policy_type, RandomLoadBalance) + policy_instance = policy_class(config) + if policy_type == 3: + policy_instance.warm_up() + return policy_instance \ No newline at end of file diff --git a/vllm_ascend/eplb/core/policy/policy_flashlb.py b/vllm_ascend/eplb/core/policy/policy_flashlb.py new file mode 100644 index 0000000..2bf6551 --- /dev/null +++ b/vllm_ascend/eplb/core/policy/policy_flashlb.py @@ -0,0 +1,651 @@ +# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy. 
+ +import logging +from collections import deque +from typing import Dict + +import numpy as np +import torch +from numba import njit # type: ignore + +from .policy_abstract import DynamicConfig, EplbPolicy + +numba_logger = logging.getLogger("numba") +numba_logger.setLevel(logging.WARNING) + + +@njit +def compute_piece_counts(X, P, stage_weights): + n_stage, N = X.shape + S = P - N + pieces = np.ones(N, dtype=np.int32) + unit = X / pieces # unit[i, j] = X[i, j] / pieces[j] + + for _ in range(S): + deltas = np.zeros(N, dtype=np.float32) + for i in range(n_stage): + # Find top1 and top2 + idx1 = -1 + idx2 = -1 + val1 = -1.0 + val2 = -1.0 + for j in range(N): + v = unit[i, j] + if v > val1: + val2 = val1 + idx2 = idx1 + val1 = v + idx1 = j + elif v > val2: + val2 = v + idx2 = j + + origin = unit[i, idx1] + secv = unit[i, idx2] + alt = X[i, idx1] / (pieces[idx1] + 1) + delta = origin - (alt if alt > secv else secv) + deltas[idx1] += delta * stage_weights[i] if np.any( + delta) != 0 else stage_weights[i] + + max_idx = np.argmax(deltas) + pieces[max_idx] += 1 + for i in range(n_stage): + unit[i, max_idx] = X[i, max_idx] / pieces[max_idx] + + # Compute max load + max_load = 0.0 + for j in range(N): + total = 0.0 + for i in range(n_stage): + total += unit[i, j] + if total > max_load: + max_load = total + + return pieces + + +@njit +def jsq_placement(X, pieces, M, stage_weights): + n_stage, N = X.shape + total_piece = pieces.sum() + num_per_group = total_piece // M + + # 1. Compute unit_hotness + unit_hotness = np.empty((n_stage, N), dtype=np.float32) + for i in range(N): + if pieces[i] > 0: + for s in range(n_stage): + unit_hotness[s, i] = X[s, i] / pieces[i] + else: + for s in range(n_stage): + unit_hotness[s, i] = 0.0 + + # 2. Sort by total hotness + scores = np.zeros(N, dtype=np.float32) + for i in range(N): + for s in range(n_stage): + scores[i] += unit_hotness[s, i] + idx = np.argsort(-scores) + + # 3. 
Initialization + loads = np.zeros((n_stage, M), dtype=np.float32) + dev_phy_exp_n = np.zeros(M, dtype=np.int32) + deployment = -np.ones((M, num_per_group), dtype=np.int32) + dep_ptr = np.zeros(M, dtype=np.int32) + + # 4. Main loop + for t in range(N): + i = idx[t] + used_device = list() + for _ in range(pieces[i]): + # 4.1 Construct w vector + w = np.empty(n_stage, dtype=np.float32) + for s in range(n_stage): + w[s] = unit_hotness[s, i] + + # 4.2 Compute stage-level maximum load + stage_max = np.empty(n_stage, dtype=np.float32) + for s in range(n_stage): + max_val = loads[s, 0] + for k in range(1, M): + if loads[s, k] > max_val: + max_val = loads[s, k] + stage_max[s] = max_val + + # 4.3 Compute denominator + denom = np.empty(n_stage, dtype=np.float32) + for s in range(n_stage): + sum_tmp = 0.0 + for j in range(M): + sum_tmp += loads[s, j] + w[s] + denom[s] = sum_tmp / M + 1e-2 + + # 4.4 Find best device j + best_j = -1 + best_val = 1e30 + for j in range(M): + if dev_phy_exp_n[j] >= num_per_group: + continue + if j in used_device: + continue + score = 0.0 + for s in range(n_stage): + tmp_sj = loads[s, j] + w[s] + numer_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s] + score += stage_weights[s] * (numer_sj / denom[s]) + if score < best_val: + best_val = score + best_j = j + if best_j == -1: + continue + + used_device.append(best_j) + + # 4.5 Update status + for s in range(n_stage): + loads[s, best_j] += w[s] + ptr = dep_ptr[best_j] + deployment[best_j, ptr] = i + dep_ptr[best_j] += 1 + dev_phy_exp_n[best_j] += 1 + + # Handle remaining -1 values: fill with random elements from range(N) not in current column + for rank in range(M): + for col in range(num_per_group): + if deployment[rank, col] == -1: + # Get elements already in current column + current_rank_elements = set(deployment[rank, :]) + # Filter elements from range(N) not in current column + available = [ + x for x in range(N) if x not in current_rank_elements + ] + # Randomly select an available element 
to fill + if len(available) > 0: + rand_idx = np.random.randint(0, len(available)) + deployment[rank, col] = available[rand_idx] + elif N > 0: + # All unique experts are already in this rank's column, so we can pick any expert randomly. + deployment[rank, col] = np.random.randint(0, N) + + return deployment + + +@njit +def slice_values(X, pieces): + total_len = 0 + for i in range(X.shape[0]): + total_len += pieces[i] + result = np.empty(total_len, dtype=np.float32) + idx = 0 + for i in range(X.shape[0]): + val = X[i] / pieces[i] + for _ in range(pieces[i]): + result[idx] = val + idx += 1 + return result + + +@njit +def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, + simulated_deployment, stage_weights): + n_stage, N = X.shape + num_group = P // M + + X_all = np.zeros(N, dtype=np.float32) + for i in range(n_stage): + for j in range(N): + X_all[j] += X[i, j] + + sort_idx = np.argsort(np.negative(X_all)) + X_sorted = X[:, sort_idx] + + unit_load = np.empty(N, dtype=np.float32) + for j in range(N): + unit_load[j] = X_all[j] / simulated_pieces[j] + + flat_deployment = simulated_deployment.reshape(-1) + simulated_load = np.zeros(M, dtype=np.float32) + for i in range(flat_deployment.shape[0]): + simulated_load[i // (flat_deployment.shape[0] // + M)] += unit_load[flat_deployment[i]] + + slice_vals = slice_values(X_all, simulated_pieces) + sorted_slices = np.sort(slice_vals)[::-1] + simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M + + cumulative_slices_used = np.zeros(N, dtype=np.int32) + acc = 0 + for i in range(N): + acc += simulated_pieces[sort_idx[i]] + cumulative_slices_used[i] = acc + + group_boundary_indices = np.zeros(num_group, dtype=np.int32) + for i in range(1, num_group + 1): + for j in range(N): + if cumulative_slices_used[j] >= i * M: + group_boundary_indices[i - 1] = j + break + + slices_used_per_group = np.zeros(num_group, dtype=np.int32) + slices_used_per_group[0] = group_boundary_indices[0] + for i in range(1, 
num_group): + slices_used_per_group[ + i] = group_boundary_indices[i] - group_boundary_indices[i - 1] + slices_used_per_group = M - slices_used_per_group + + loads = np.zeros(M, dtype=np.float32) + pieces = np.zeros(N, dtype=np.int32) + num_remain_slice = P - N + current_idx = 0 + + for g in range(num_group): + window = X_sorted[:, current_idx:current_idx + 2 * M] + low = max(0, current_idx + M - N) + high = min(num_remain_slice, M - 1) + + while (high - low) > 1: + mid = int((high + low) // 2) + keep = M - mid + current_group = window[:, :keep] + current_pieces = compute_piece_counts(current_group, M, + stage_weights) + current_pieces = np.maximum(current_pieces, 1) + current_slice = slice_values(current_group.sum(0), current_pieces) + current_slice_sorted = np.sort(current_slice) + current_loads = loads + current_slice_sorted + current_max: np.float32 = np.max(current_loads) + current_min: np.float32 = np.min(current_loads) + current_slope = (current_max - current_min) / M + next_slope: np.float32 = np.max(simulated_slopes[current_idx + + keep:]) + + if abs(current_slope) > abs(next_slope): + low = mid + else: + high = mid + + S = high + keep = M - S + current_group = window[:, :keep] + current_pieces = compute_piece_counts(current_group, M, stage_weights) + + for i in range(keep): + pieces[sort_idx[current_idx + i]] = current_pieces[i] + + current_slice = slice_values(current_group.sum(0), current_pieces) + current_slice_sorted = np.sort(current_slice) + loads += current_slice_sorted + loads = np.sort(loads)[::-1] + + current_idx += keep + num_remain_slice -= S + + return pieces + + +@njit +def compute_objective(deployment, X, pieces): + M, P = deployment.shape + loads = np.zeros(M) + + for i in range(M): + for j in range(P): + expert = deployment[i, j] + if pieces[expert] == 0: + continue + loads[i] += X[expert] / pieces[expert] + + mean_load = np.mean(loads) + max_load: np.float32 = np.max(loads) + obj = max_load / mean_load + return obj, loads + + +@njit +def 
auto_fix_new_placement(old_placement, new_placement): + """ + Adjust the new_placement matrix to ensure elements (including duplicates) that exist in both + old_placement and new_placement remain in their original positions from old_placement. + New elements (unique to new_placement) will fill the remaining empty positions. + + Args: + old_placement: Old deployment matrix with shape (num_ranks, num_experts) + new_placement: New deployment matrix to be fixed, must have the same shape as old_placement + + Returns: + fixed_new: adjusted version of the new_placement matrix + """ + num_ranks, num_experts = old_placement.shape + fixed_new = np.empty_like(new_placement) + + max_expert_old = old_placement.max() if num_experts > 0 else 0 + max_expert_new = new_placement.max() if num_experts > 0 else 0 + max_expert = max(max_expert_old, max_expert_new) + + for rank_id in range(num_ranks): + old_row = old_placement[rank_id] + new_row = new_placement[rank_id] + + index_array = np.full((max_expert + 1, num_experts), + -1, + dtype=np.int32) + count_array = np.zeros(max_expert + 1, dtype=np.int32) + + for idx in range(num_experts): + val = old_row[idx] + if val >= 0 and val <= max_expert: + pos = count_array[val] + index_array[val, pos] = idx + count_array[val] += 1 + + old_counter = np.zeros(max_expert + 1, dtype=np.int32) + for idx in range(num_experts): + val = old_row[idx] + if val >= 0 and val <= max_expert: + old_counter[val] += 1 + + retain_elements = np.empty(num_experts, dtype=new_placement.dtype) + new_elements = np.empty(num_experts, dtype=new_placement.dtype) + retain_ptr = 0 + new_ptr = 0 + + for val in new_row: + if val >= 0 and val <= max_expert and old_counter[val] > 0: + retain_elements[retain_ptr] = val + retain_ptr += 1 + old_counter[val] -= 1 + else: + new_elements[new_ptr] = val + new_ptr += 1 + + current_fixed = np.full(num_experts, -1, dtype=new_placement.dtype) + + for i in range(retain_ptr): + val = retain_elements[i] + if val >= 0 and val <= max_expert: 
+ pos = count_array[val] - 1 + if pos >= 0: + idx = index_array[val, pos] + current_fixed[idx] = val + count_array[val] -= 1 + + empty_indices = np.empty(num_experts, dtype=np.int32) + empty_ptr = 0 + for idx in range(num_experts): + if current_fixed[idx] == -1: + empty_indices[empty_ptr] = idx + empty_ptr += 1 + + for i in range(new_ptr): + if i < empty_ptr: + current_fixed[empty_indices[i]] = new_elements[i] + + fixed_new[rank_id] = current_fixed + + return fixed_new + + +class FlashLB(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + self.par_history: Dict[int, float] = {} + self.hotness_window: Dict[int, deque[float]] = {} + self.max_stage_window = (config.max_stage_window if hasattr( + config, "max_stage_window") else 1) + self.buffer_expert_layer_num = ( + config.buffer_expert_layer_num if hasattr( + config, "buffer_expert_layer_num") else 58) + self.threshold_ratio = (config.threshold_ratio if hasattr( + config, "threshold_ratio") else 0) + + def compute_expert_hotness(self, num_of_expert: int, + deployment: np.ndarray, rank_load: np.ndarray): + hotness = np.zeros(num_of_expert, dtype=rank_load.dtype) + deployment_flat = deployment.ravel() + rank_load_flat = rank_load.ravel() + np.add.at(hotness, deployment_flat, rank_load_flat) + return hotness + + def compute_rank_load(self, deployment: np.ndarray, hotness: np.ndarray): + n_stage, N = hotness.shape + if np.any(deployment < 0): + print(f"Invalid deployment with negative values: {deployment}") + raise ValueError("Deployment table contains negative values.") + counts = np.bincount(deployment.reshape(-1), minlength=N) + unit_hotness = np.divide(hotness, + counts, + out=np.zeros_like(hotness, dtype=float), + where=counts != 0) + stage_par = np.zeros(n_stage) + for i in range(n_stage): + stage_load = unit_hotness[i][deployment].sum(-1) + stage_par[i] = stage_load.max() / stage_load.mean() + return stage_par.mean() + + def group_based_adaptive_bloating(self, + X, + P, + M, + 
stage_weights=None, + recorsive=False): + n_stage, N = X.shape + if stage_weights is None: + stage_weights = np.ones(n_stage, dtype=np.float32) + + if recorsive: + ( + simulated_deployment, + simulated_pieces, + ) = self.group_based_adaptive_bloating(X, + P, + M, + stage_weights, + recorsive=False) + else: + simulated_pieces = compute_piece_counts(X, P, stage_weights) + simulated_deployment = jsq_placement(X, simulated_pieces, M, + stage_weights) + + pieces = group_based_adaptive_bloating_kernel( + X.astype(np.float32), + P, + M, + simulated_pieces.astype(np.int32), + simulated_deployment.astype(np.int32), + stage_weights.astype(np.float32), + ) + + deployment = jsq_placement(X, pieces, M, stage_weights) + + X_all = X.sum(0) + unit_load = np.divide(X_all, + pieces, + out=np.zeros_like(X_all, dtype=float), + where=pieces != 0) + load = unit_load[deployment].sum(-1) + + sim_unit_load = X_all / simulated_pieces + sim_load = sim_unit_load[simulated_deployment].sum(-1) + + if load.max() > sim_load.max(): + return simulated_deployment, simulated_pieces + return deployment, pieces + + def need_update(self, current_par, layer_id=0): + threshold = self.par_history.get(layer_id, 0.0) + return current_par >= self.threshold_ratio * threshold + + def compute_stage_weight(self, hotness): + n_stage = hotness.shape[0] + stage_weights = np.zeros(n_stage) + for i in range(n_stage): + stage_weights[i] = hotness[i].sum() + + stage_weights = stage_weights / stage_weights.max() + return stage_weights + + def rebalance_layer(self, deployment, hotness, layer_id=0): + num_rank, expert_per_rank = deployment.shape + num_expert = np.unique(deployment.reshape(-1)).shape[0] + num_of_redundant_expert = num_rank * expert_per_rank - num_expert + + current_par = self.compute_rank_load(deployment, hotness) + + if not self.need_update(current_par, layer_id): + return deployment, current_par, current_par + + stage_weights = self.compute_stage_weight(hotness) + new_deployment, _ = 
self.group_based_adaptive_bloating( + hotness, + num_expert + num_of_redundant_expert, + num_rank, + stage_weights, + recorsive=False, + ) + if np.any(new_deployment < 0): + print(f"{new_deployment=}") + new_par = self.compute_rank_load(new_deployment, hotness) + + return new_deployment, new_par, current_par + + def register_hotness(self, deployment, rank_load, num_layer, num_expert): + for layer in range(num_layer): + if layer not in self.hotness_window: + self.hotness_window[layer] = deque( + maxlen=self.max_stage_window) + hotness = self.compute_expert_hotness(num_expert, + deployment[layer], + rank_load[layer]) + self.hotness_window[layer].append(hotness) + + def compress_by_avg_pooling_fast_nd(self, arr, m): + n, d = arr.shape + idx = (np.arange(n) * m // n) + result = np.zeros((m, d)) + counts = np.zeros((m, 1)) + np.add.at(result, idx, arr) + np.add.at(counts, idx, 1) + return result / counts + + def rebalance_experts(self, current_expert_table, expert_workload): + current_deployment = np.array(current_expert_table) + expert_workload = np.array(expert_workload) + expert_workload += 1 + num_layer = expert_workload.shape[0] + num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0] + self.register_hotness(current_deployment, expert_workload, num_layer, + num_expert) + + new_deployment = current_deployment.copy() + + layers_need_update = np.arange(num_layer) + + new_par = np.zeros(layers_need_update.shape[0]) + current_par = np.zeros(layers_need_update.shape[0]) + for i, layer in enumerate(layers_need_update): + hotness = np.array(self.hotness_window[layer]) + if hotness.shape[0] > self.max_stage_window: + hotness = self.compress_by_avg_pooling_fast_nd( + hotness, self.max_stage_window) + + ( + new_deployment[layer], + new_par[i], + current_par[i], + ) = self.rebalance_layer(current_deployment[layer], + hotness, + layer_id=layer) + + priority = new_par / current_par + priority_idx = np.argsort(priority) + priority_idx = 
priority_idx[priority[priority_idx] < + 1][:self.buffer_expert_layer_num] + + if np.all(expert_workload == 1): + for _, layer in enumerate(layers_need_update): + self.hotness_window[layer].pop() + return False, np.array([], dtype=int), current_deployment + change = len(priority_idx) > 0 + if change: + for idx in priority_idx: + self.par_history[layers_need_update[idx]] = new_par[idx] + + layers_need_update = priority_idx + deployment = current_deployment + for layer in layers_need_update: + deployment[layer] = auto_fix_new_placement( + current_deployment[layer], new_deployment[layer]) + + return change, layers_need_update, deployment + + +def generate_layered_experts(num_layers=58, + layer_shape=(32, 9), + expert_min=0, + expert_max=255): + """ + Generate expert deployment matrix meeting the following conditions: + - Total of num_layers layers + - Each layer has shape layer_shape (32,9) + - Each expert from expert_min to expert_max (0 to 255) appears at least once in each layer + + Args: + num_layers: Number of layers, default 58 + layer_shape: Shape of a single layer, default (32,9) + expert_min: Minimum expert ID, default 0 + expert_max: Maximum expert ID, default 255 + Returns: + torch.Tensor: Tensor with shape (num_layers, layer_shape[0], layer_shape[1]) + """ + # 1. Basic parameter calculation + expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255) + layer_total = layer_shape[0] * layer_shape[ + 1] # Total elements in a single layer: 32*9=288 + extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32 + + # 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts) + assert layer_total >= expert_num, ( + f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, " + "cannot cover all experts") + + # 3. 
Generate layers one by one + layers = [] + for _ in range(num_layers): + # 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included) + full_experts = torch.arange(expert_min, + expert_max + 1, + dtype=torch.int64) # shape (256,) + + # 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255) + extra_experts = torch.randint(expert_min, + expert_max + 1, + size=(extra_slots, ), + dtype=torch.int64) # shape (32,) + + # 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer) + layer_flat = torch.cat([full_experts, extra_experts], + dim=0) # shape (288,) + # Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues) + shuffle_idx = torch.randperm(layer_flat.shape[0]) + layer_shuffled = layer_flat[shuffle_idx] + + # 3.4 Reshape to layer_shape (32,9) + layer = layer_shuffled.reshape(layer_shape) + layers.append(layer) + + # 4. Stack all layers to get the final tensor + return torch.stack(layers, dim=0) # shape (58,32,9) + + +def warm_up(): + exam_config = DynamicConfig() + exam_config.ep_worldsize = 32 + exam_config.num_die_per_host = 16 + algo = FlashLB(exam_config) + # Generate target tensor + expert_tensor = generate_layered_experts(num_layers=58, + layer_shape=(32, 9)) + + algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9))) diff --git a/vllm_ascend/eplb/core/policy/policy_random.py b/vllm_ascend/eplb/core/policy/policy_random.py new file mode 100644 index 0000000..558d653 --- /dev/null +++ b/vllm_ascend/eplb/core/policy/policy_random.py @@ -0,0 +1,30 @@ +# Copyright # Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy. 
+import copy +import random + +from .policy_abstract import DynamicConfig, EplbPolicy + +random.seed(42) + + +class RandomLoadBalance(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + def rebalance_experts(self, current_expert_table, expert_workload): + new_table = copy.deepcopy(current_expert_table) + num_layers = len(current_expert_table) + + for i in range(num_layers): + # randomly choose two card + # indices = random.sample(range(num_card), 2) + indices = [3, 1] + + # swap redundant experts + expert_id_to_exchange = new_table[i][indices[0]][-1].clone() + new_table[i][indices[0]][-1] = new_table[i][indices[1]][-1] + new_table[i][indices[1]][-1] = expert_id_to_exchange + + return 1, [-i for i in range(num_layers)], new_table diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py new file mode 100644 index 0000000..1f25f8f --- /dev/null +++ b/vllm_ascend/eplb/eplb_updator.py @@ -0,0 +1,205 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this updator. 
+import numpy +import torch +import torch.distributed as dist +import vllm.envs as envs +from vllm.logger import logger + +from vllm_ascend.eplb.core.eplb_worker import EplbProcess + + +class EplbUpdator: + + def __init__(self, ascend_config, loader, eplb_process: EplbProcess, + process): + self.ascend_config = ascend_config + self.init_eplb(self.ascend_config.expert_map_path, process) + self.eplb_loader = loader + self.eplb_process = eplb_process + self.shared_dict = self.eplb_process.shared_dict + + def set_adaptor(self, adaptor): + self.adaptor = adaptor + self.num_moe_layers = self.adaptor.num_moe_layers + self.global_expert_num = self.adaptor.global_expert_num + + def init_eplb(self, expert_map_path, process): + self.rank_id = dist.get_rank() + self.num_expert_load_gather = 10 + self.periodic_load_gather = True + self.num_iterations_eplb_update: torch.int64 = self.ascend_config.num_iterations_eplb_update + self.expert_map_path = expert_map_path + self.expert_map_record_path = self.ascend_config.expert_map_record_path + + try: + if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING: + self.num_expert_load_gather = self.num_iterations_eplb_update + self.periodic_load_gather = False + except Exception: + self.num_expert_load_gather = self.num_iterations_eplb_update + self.periodic_load_gather = False + + self.expert_map_initialized = False + self.gate_eplb = self.ascend_config.gate_eplb + + self.reqs = [] + self.update_info_all = [] + + self.cur_iterations: torch.int64 = 0 + + self.num_wait_worker_iterations: torch.int64 = self.ascend_config.num_wait_worker_iterations + + self.process = process + + logger.info( + f"[ModelRunner] Launched EPLB process (pid={self.process.pid})") + + def update_iteration(self): + self.cur_iterations += 1 + if self.cur_iterations == (self.num_iterations_eplb_update + \ + self.num_wait_worker_iterations + self.num_moe_layers): + if self.expert_map_record_path is not None: + self.adaptor._export_tensor_to_file( + 
self.shared_dict["expert_maps"], + self.expert_map_record_path) + + self.adaptor.model.clear_all_moe_loads() + if not self.gate_eplb: + self.cur_iterations = 0 + + def get_update_info_flag(self): + return self.cur_iterations == (self.num_iterations_eplb_update + + self.num_wait_worker_iterations - 1) + + def wakeup_eplb_worker_flag(self): + return self.cur_iterations == (self.num_iterations_eplb_update - 1) + + def update_expert_weight_flag(self): + weight_update_counter = self.cur_iterations - ( + self.num_iterations_eplb_update + self.num_wait_worker_iterations) + return (weight_update_counter >= 0 + and weight_update_counter < self.num_moe_layers) + + def get_init_expert_map(self): + try: + if not self.expert_map_initialized: + self.shared_dict[ + "expert_maps"] = self.adaptor.get_init_expert_map_from_file( + self.num_moe_layers, self.expert_map_path) + self.expert_map_initialized = True + except Exception as e: + logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}", + exc_info=True) + + def wakeup_eplb_worker(self): + self.eplb_process.planner_q.put(1) + + def forward_before(self): + if self.update_expert_weight_flag(): + (expert_send_info, expert_recv_info, updated_expert_map, + log2phy_map, layer_id) = self.update_info_all.pop(0) + log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map)) + self.eplb_loader.set_log2phy_map(log2phy_map_this_rank) + updated_expert_map_this_rank = torch.from_numpy( + numpy.array(updated_expert_map)) + self.eplb_loader.generate_expert_d2d_transfer_task( + expert_send_info, expert_recv_info, + updated_expert_map_this_rank, + layer_id + self.adaptor.num_dense_layers) + + # set asynchronous stream for d2d expert weight update + self.reqs = [] + self.eplb_loader.asyn_expert_weight_transfer(self.reqs) + + def take_update_info_from_eplb_process(self): + # Batch after eplb process being triggered, get update info provided by eplb process + if self.get_update_info_flag(): + self.update_info_all = 
self.eplb_process.block_update_q.get() + + def forward_end(self): + if self.wakeup_eplb_worker_flag(): + self.compute_and_set_moe_load(is_clear=True) + self.wakeup_eplb_worker() + + if self.update_expert_weight_flag(): + self.eplb_loader.update_expert_map_and_weight(self.reqs) + + self.update_iteration() + + def compute_and_set_moe_load(self, is_clear=False): + local_load = self.adaptor.get_rank_expert_workload() + + self._gather_buffer = None + if dist.is_initialized(): + self.world_size = dist.get_world_size() + self.device = local_load.device + if self._gather_buffer is None: + shape = (self.world_size, *local_load.shape) + self._gather_buffer = torch.empty(shape, + dtype=local_load.dtype, + device=self.device) + + dist.all_gather_into_tensor(self._gather_buffer, local_load) + + moe_load = self._gather_buffer.permute(1, 0, 2) + self.shared_dict["moe_load"] = moe_load.cpu() + logger.debug( + f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}" + ) + else: + moe_load = local_load.unsqueeze(1) + self.shared_dict["moe_load"] = moe_load.cpu() + logger.debug( + f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}" + ) + return moe_load + + def warm_up_eplb(self): + + self.get_init_expert_map() + self.compute_and_set_moe_load() + + src_tensor = torch.empty((1, ), device=self.device) + self_rank = dist.get_rank() + + comm_op_list = [] + + for dst_rank in range(self.world_size): + if dst_rank == self_rank: + continue + comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank)) + + for src_rank in range(self.world_size): + if src_rank == self_rank: + continue + comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank)) + if comm_op_list: + reqs = dist.batch_isend_irecv(comm_op_list) + + for req in reqs: + req.wait() + + def shutdown(self): + """ + Clean up the EPLB process. 
+ """ + if self.process.is_alive(): + self.process.terminate() + self.process.join() + logger.info("[ModelRunner] EPLB process terminated") diff --git a/vllm_ascend/eplb/utils.py b/vllm_ascend/eplb/utils.py new file mode 100644 index 0000000..71b4487 --- /dev/null +++ b/vllm_ascend/eplb/utils.py @@ -0,0 +1,77 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +# Todo: Once https://github.com/vllm-project/vllm/pull/23553 is merged in vllm. Remove this model register. 
+import types + +import torch + + +def get_expert_map(self, layer_id): + return self.model.layers[layer_id].mlp.experts.get_map() + + +def get_log2phy_map(self, layer_id): + return self.model.layers[layer_id].mlp.experts.get_log2phy_map() + + +def get_all_expert_map(self, num_moe_layers): + all_loads = [] + num_dense_layers = self.num_dense_layers if hasattr( + self, "num_dense_layers") else 0 + for layer_id in range(num_moe_layers): + load_tensor = self.get_expert_map( + layer_id + num_dense_layers) # (num_experts_per_layer,) + all_loads.append(load_tensor) + + return torch.stack(all_loads, dim=0) + + +def get_all_moe_loads(self): + num_dense_layers = self.num_dense_layers if hasattr( + self, "num_dense_layers") else 0 + all_moe_loads = torch.stack( + [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \ + for layer_id in range(self.num_moe_layers)], + dim=0 + ) + return all_moe_loads + + +def clear_all_moe_loads(self): + num_dense_layers = self.num_dense_layers if hasattr( + self, "num_dense_layers") else 0 + for layer_id in range(self.num_moe_layers): + self.model.layers[layer_id + + num_dense_layers].mlp.experts.clear_moe_load() + + +def model_register(model, model_config): + model.get_expert_map = types.MethodType(get_expert_map, model) + model.get_log2phy_map = types.MethodType(get_log2phy_map, model) + model.get_all_expert_map = types.MethodType(get_all_expert_map, model) + model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model) + model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model) + + config = model_config.hf_config + + if config.model_type == "qwen3_moe": + model.num_moe_layers = config.num_hidden_layers + elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3": + num_dense_layers = config.first_k_dense_replace + model.num_moe_layers = config.num_hidden_layers - num_dense_layers + else: + raise NotImplementedError("EPLB is not supported.") diff --git 
a/vllm_ascend/lora/punica_wrapper/lora_ops.py b/vllm_ascend/lora/lora_ops.py similarity index 78% rename from vllm_ascend/lora/punica_wrapper/lora_ops.py rename to vllm_ascend/lora/lora_ops.py index e8bf8ad..58d0ea6 100644 --- a/vllm_ascend/lora/punica_wrapper/lora_ops.py +++ b/vllm_ascend/lora/lora_ops.py @@ -21,7 +21,7 @@ def bgmv_shrink(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - return torch.ops._C.bgmv_shrink( + return torch.ops._C_ascend.bgmv_shrink( inputs, lora_a_weights, lora_indices_tensor, @@ -35,7 +35,7 @@ def bgmv_expand(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): - return torch.ops._C.bgmv_expand( + return torch.ops._C_ascend.bgmv_expand( inputs, lora_b_weights, lora_indices_tensor, @@ -52,9 +52,9 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - return torch.ops._C.bgmv_expand(inputs, lora_b_weights, - lora_indices_tensor, output_tensor, - slice_offset, slice_size) + return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, output_tensor, + slice_offset, slice_size) def sgmv_shrink( @@ -69,9 +69,9 @@ def sgmv_shrink( token_nums: int, scaling: float, ): - return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, - lora_indices_tensor, seq_len_tensor, - output_tensor, scaling) + return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, scaling) def sgmv_expand(inputs: torch.Tensor, @@ -84,7 +84,7 @@ def sgmv_expand(inputs: torch.Tensor, max_seq_length: int, token_nums: int, add_inputs: bool = False): - return torch.ops._C.sgmv_expand( + return torch.ops._C_ascend.sgmv_expand( inputs, lora_b_weights, lora_indices_tensor, @@ -107,6 +107,7 @@ def sgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = False): - return 
torch.ops._C.sgmv_expand(inputs, lora_b_weights, - lora_indices_tensor, seq_len_tensor, - output_tensor, slice_offset, slice_size) + return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, slice_offset, + slice_size) diff --git a/vllm_ascend/lora/punica_wrapper/punica_npu.py b/vllm_ascend/lora/punica_npu.py similarity index 94% rename from vllm_ascend/lora/punica_wrapper/punica_npu.py rename to vllm_ascend/lora/punica_npu.py index a85c837..db4adc4 100644 --- a/vllm_ascend/lora/punica_wrapper/punica_npu.py +++ b/vllm_ascend/lora/punica_npu.py @@ -11,12 +11,14 @@ if is_310p(): bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) else: - from vllm_ascend.lora.punica_wrapper.lora_ops import ( - bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) + from vllm_ascend.lora.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase +from vllm_ascend.lora.utils import refresh_all_lora_classes + # The platforms that are compatible with the PyTorch-native implementation can # inherit this class @@ -31,6 +33,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) + refresh_all_lora_classes() def _shrink_prefill( self, @@ -338,13 +341,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) - - if lora_a_stacked.dim() == 2: - lora_a_stacked = lora_a_stacked.unsqueeze(0) - if lora_b_stacked.dim() == 2: - lora_b_stacked = lora_b_stacked.unsqueeze(0) - - r = lora_a_stacked.size(-1) + r = lora_b_stacked.size(-1) if buffer is None: buffer = torch.zeros((x.size(0), r), @@ -352,13 +349,8 @@ class PunicaWrapperNPU(PunicaWrapperBase): device=x.device) indices = self.sampler_indices - if 
indices.max() >= lora_a_stacked.size(0): - indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1) - lora_a_reshaped = lora_a_stacked.transpose(1, 2) - lora_b_reshaped = lora_b_stacked.transpose(1, 2) - - bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale) - bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True) + bgmv_shrink(x, lora_a_stacked, buffer, indices, scale) + bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True) y = y.view_as(y_org) diff --git a/vllm_ascend/lora/utils.py b/vllm_ascend/lora/utils.py new file mode 100644 index 0000000..be4fbeb --- /dev/null +++ b/vllm_ascend/lora/utils.py @@ -0,0 +1,110 @@ +from typing import Optional + +import vllm +from torch import nn +from transformers import PretrainedConfig +from vllm.config import LoRAConfig +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +from vllm.lora.layers.utils import _not_fully_sharded_can_replace + +from vllm_ascend.ops.linear import (AscendColumnParallelLinear, + AscendMergedColumnParallelLinear, + AscendQKVParallelLinear, + AscendRowParallelLinear) +from vllm_ascend.ops.vocab_parallel_embedding import \ + AscendVocabParallelEmbedding + + +class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is AscendColumnParallelLinear + + +class AscendMergedColumnParallelLinearWithLoRA( + MergedColumnParallelLinearWithLoRA): + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is 
AscendMergedColumnParallelLinear + + +class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA): + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is AscendRowParallelLinear + + +class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA): + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is AscendVocabParallelEmbedding + + +class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA): + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: list, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is AscendQKVParallelLinear and len( + packed_modules_list) == 1 + + +class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA): + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is AscendQKVParallelLinear + and len(packed_modules_list) == 3) + + +def refresh_all_lora_classes(): + vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add( + AscendMergedColumnParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA) + vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add( + AscendMergedQKVParallelLinearWithLoRA) diff --git 
a/vllm_ascend/meta_registration.py b/vllm_ascend/meta_registration.py index 47c7758..9a58afd 100644 --- a/vllm_ascend/meta_registration.py +++ b/vllm_ascend/meta_registration.py @@ -23,7 +23,7 @@ from torch.library import Library # Do NOT perform any real computation or allocate device memory. # # 2. Register your meta function using `register_meta_if_necessary`, providing: -# - The namespace (usually "_C" for custom ops) +# - The namespace (usually "_C_ascend" for custom ops) # - The operator name (as registered in C++) # - The Python meta function # - (Optional) The overload name, if your op has overloads @@ -39,7 +39,7 @@ from torch.library import Library # # For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors -lib = Library("_C", "IMPL") +lib = Library("_C_ascend", "IMPL") def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""): @@ -97,8 +97,9 @@ def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, return y_out -register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta) -register_meta_if_necessary("_C", "get_masked_input_and_mask", +register_meta_if_necessary("_C_ascend", "rotary_embedding", + rotary_embedding_meta) +register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", get_masked_input_and_mask_meta) -register_meta_if_necessary("_C", "bgmv_expand", bgmv_expand_meta) -register_meta_if_necessary("_C", "sgmv_expand", sgmv_expand_meta) +register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta) +register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index dfb47fe..8577abe 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -4,23 +4,20 @@ import vllm_ascend.envs as envs_ascend def register_model(): - from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401 - from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 - 
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 - from .deepseek_v3 import CustomDeepseekV3ForCausalLM # noqa: F401 - from .qwen2_5_vl import \ - AscendQwen2_5_VLForConditionalGeneration # noqa: F401 - from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 - from .qwen3 import CustomQwen3ForCausalLM # noqa: F401 - - ModelRegistry.register_model( - "DeepSeekMTPModel", - "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP") - ModelRegistry.register_model( "Qwen2VLForConditionalGeneration", "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") + ModelRegistry.register_model( + "Qwen3VLMoeForConditionalGeneration", + "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLMoeForConditionalGeneration" + ) + + ModelRegistry.register_model( + "Qwen3VLForConditionalGeneration", + "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLForConditionalGeneration" + ) + if envs_ascend.USE_OPTIMIZED_MODEL: ModelRegistry.register_model( "Qwen2_5_VLForConditionalGeneration", @@ -32,30 +29,32 @@ def register_model(): "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding" ) - if envs_ascend.VLLM_ASCEND_ENABLE_DBO: - ModelRegistry.register_model( - "DeepseekV2ForCausalLM", - "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") - ModelRegistry.register_model( - "DeepseekV3ForCausalLM", - "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") - else: - ModelRegistry.register_model( - "DeepseekV2ForCausalLM", - "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") + ModelRegistry.register_model( + "DeepseekV3ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") - ModelRegistry.register_model( - "DeepseekV3ForCausalLM", - "vllm_ascend.models.deepseek_v3:CustomDeepseekV3ForCausalLM") + 
ModelRegistry.register_model( + "DeepseekV32ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") + + ModelRegistry.register_model( + "DeepSeekMTPModel", + "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP") ModelRegistry.register_model( "Qwen3MoeForCausalLM", "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") - ModelRegistry.register_model( - "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM") - + # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization + # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. ModelRegistry.register_model( "PanguProMoEForCausalLM", - "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM") + "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM" + ) + ModelRegistry.register_model( + "Qwen3NextForCausalLM", + "vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM") diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py deleted file mode 100644 index 9469e99..0000000 --- a/vllm_ascend/models/deepseek_dbo.py +++ /dev/null @@ -1,1046 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# # Adapted from -# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py -# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py -# """Inference-only DeepseekV2/DeepseekV3 model.""" - -from typing import Any, Dict, Iterable, List, Optional, Union - -import torch -import torch.distributed as dist -import torch_npu # noqa: F401 -from torch import nn -from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - get_tp_group, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_dp_group -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - ReplicatedLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.deepseek_v2 
import \ - DeepseekV2ForCausalLM # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import \ - yarn_get_mscale # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import ( - DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention, - get_spec_layer_idx_from_weight_name) -from vllm.model_executor.models.utils import ( - PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.sequence import IntermediateTensors - -import vllm_ascend.envs as envs_ascend -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLP, - CustomDeepseekV2RowParallelLinear) -from vllm_ascend.multistream.base import MSEventKey -from vllm_ascend.multistream.context import ( - advance_step_multistream_layer_context, get_multistream_comm_context, - get_multistream_layer_context, set_multistream_context) -from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, - MultiStreamPreTransformerLayer) -from vllm_ascend.multistream.metadata import (MultiStreamConfig, - MultiStreamStepMetadata, - make_multistream_metadata_ds) -from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.utils import dispose_tensor - -VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO - - -class CustomDeepseekDBOMLP(CustomDeepseekV2MLP): - - def _forward_ms_mlp(self, x): - current_ms_metadata = get_multistream_comm_context() - assert current_ms_metadata is not None - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - x, _ = self.down_proj(x) - current_ms_metadata.after_comm_event.record() - return x - - -class CustomDeepseekDBOMoE(nn.Module): - - top_k: int - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - 
prefix: str = "", - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.routed_scaling_factor = config.routed_scaling_factor - self.n_shared_experts = config.n_shared_experts - self.routed_scaling_factor = config.routed_scaling_factor - if self.tp_size > config.n_routed_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.n_routed_experts}.") - - if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") - - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") - if config.topk_method == "noaux_tc": - self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) - else: - self.gate.e_score_correction_bias = None - - self.experts = AscendFusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias) - - if config.n_shared_experts is not None: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) - self.shared_experts = CustomDeepseekDBOMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=True, - prefix=f"{prefix}.shared_experts", - ) - CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok - - self.dp_size = get_dp_group().world_size - - self.tp_group = get_tp_group().device_group - self.tp_rank = get_tp_group().rank_in_group - - 
self.params_dtype = torch.get_default_dtype() - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - forward_context = get_forward_context() - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. - enable_force_load_balance = forward_context.in_profile_run - - is_prefill = forward_context.with_prefill - - old_hidden_states = hidden_states.clone() - - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=CustomDeepseekDBOMoE.top_k, - enable_force_load_balance=enable_force_load_balance, - ) * self.routed_scaling_factor - - if self.n_shared_experts is not None: - shared_output = self.shared_experts(old_hidden_states) - - if shared_output is not None: - hidden_states = hidden_states + shared_output - - return hidden_states - - # ----------------------------------------- TBO-related -------------------------------------------- - def _forward_ms_op_shared_expert( - self, - hidden_states: torch.Tensor, - ): - shared_output = self.shared_experts._forward_ms_mlp(hidden_states) - return shared_output - - def _forward_ms_op_gate( - self, - hidden_states: torch.Tensor, - ): - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - return router_logits - - def _forward_ms_op_tp_allgather( - self, - hidden_states: torch.Tensor, - chunk_hidden_states: torch.Tensor, - num_tokens: int = 0, - ): - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is None: - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_tokens > 0: - 
final_hidden_states = final_hidden_states[:-num_tokens] - else: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_tokens > 0: - final_hidden_states = final_hidden_states[:-num_tokens] - current_ms_metadata.after_comm_event.record() - return final_hidden_states - - -class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: Optional[int], - kv_lora_rank: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - - self.num_heads = num_heads - tp_size = get_tensor_model_parallel_world_size() - assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size - - self.scaling = self.qk_head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - 
bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") - else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") - self.o_proj = CustomDeepseekV2RowParallelLinear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - - if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) - if rope_scaling: - mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. 
- # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - ) - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - if self.q_lora_rank is not None: - ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_q_c = self.q_a_layernorm(ckq) - else: - hidden_states_or_q_c = hidden_states - if self.torchair_graph_enabled: - forward_kwargs = {} - output_shape = hidden_states.shape - output = torch.empty(output_shape, - dtype=hidden_states_or_q_c.dtype, - device=hidden_states_or_q_c.device) - forward_kwargs['output'] = output - output = self.mla_attn.impl.forward(self.mla_attn, - hidden_states_or_q_c, - hidden_states, None, kv_cache, - attn_metadata, - **forward_kwargs) - output = output.view(-1, output_shape[-1]) - return output - else: - kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - return 
self.mla_attn(hidden_states_or_q_c, - kv_c_normed, - k_pe, - output_shape=hidden_states.shape) - - -class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer): - - def __init__( - self, - config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - nn.Module.__init__(self) - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - # DecoderLayers are created with `make_layers` which passes the prefix - # with the layer's index. - layer_idx = int(prefix.split(sep='.')[-1]) - self.layer_idx = layer_idx - # TODO: enable mla in vllm-ascend - if model_config.use_mla: - attn_cls = CustomDeepseekDBOMLAAttention - else: - attn_cls = DeepseekV2Attention - self.self_attn = attn_cls( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = CustomDeepseekDBOMoE( - config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = CustomDeepseekDBOMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = 
RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.routed_scaling_factor = config.routed_scaling_factor - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - previous_hidden_states, previous_residual = hidden_states, residual - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - # Dispose hidden_states and residual from the previous layer - # to save npu memory because they're no longer used. - dispose_tensor(previous_hidden_states) - dispose_tensor(previous_residual) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - if hidden_states.dtype == torch.float16: - # Fix FP16 overflow - # We scale both hidden_states and residual before - # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor - if self.layer_idx == 0: - # The residual is shared by all layers, we only scale it on - # first layer. - residual *= 1. / self.routed_scaling_factor - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - if isinstance(self.mlp, CustomDeepseekDBOMoE): - hidden_states = self.mlp(hidden_states, attn_metadata) - else: - hidden_states = self.mlp(hidden_states) - - if isinstance( - self.mlp, - CustomDeepseekDBOMLP) and hidden_states.dtype == torch.float16: - # Fix FP16 overflow - # Scaling the DeepseekV2MLP output, it is the input of - # input_layernorm of next decoder layer. 
- # The scaling of DeepseekV2MOE output would be done in the forward - # of DeepseekV2MOE - hidden_states *= 1. / self.routed_scaling_factor - - return hidden_states, residual - - # ----------------------------------------- TBO-related -------------------------------------------- - def _forward_ms_layer( - self, - positions: List[torch.Tensor], - hidden_states: List[torch.Tensor], - residual: List[torch.Tensor], - attn_metadata: List[AttentionMetadata], - kv_cache: Optional[torch.Tensor] = None, - is_prefill: bool = False, - ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: - layer_index, ms_metadata, _ = get_multistream_layer_context() - assert layer_index >= 0 and ms_metadata is not None - num_micro_batchs = ms_metadata.ms_config.num_micro_batches - assert isinstance(self.mlp, CustomDeepseekDBOMoE) - assert len(positions) == num_micro_batchs - assert len(hidden_states) == num_micro_batchs - assert residual is not None - assert attn_metadata is not None - num_tokens = [] - hidden_dims = [] - shared_outputs = [] - router_logits = [] - chunk_hidden_states = [] - - # block 1 : attention - # block 2 : attn tp communication - # the attn computation of microbatch 1 can be overlapped with the moe - # communication in the previous layer, and the attn computation of microbatch 2 - # can be overlapped with the attn communication of microbatch 1 - for i in range(num_micro_batchs): - # wait last layer moe finishing communication - ms_metadata.try_wait_event(layer_index - 1, i, - MSEventKey.FFN_AR_FINISH) - context = MultiStreamStepMetadata( - comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.ATTN_COM_FINISH], - after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.ATTN_AR_FINISH], - ) - - with set_multistream_context(context, i): - forward_context = get_forward_context() - forward_context.attn_metadata = attn_metadata[i] - - # input layernorm - hidden_states[i], residual[ - i] = 
self._forward_ms_op_input_layernorm( - hidden_states[i], residual[i]) - # attention and tp allreduce - hidden_states[i], residual[i] = self._forward_ms_op_attn( - positions[i], hidden_states[i], residual[i], kv_cache, - attn_metadata[i]) - - # block 3 : shared experts - # if there is an allreduce ops in shared expert, we can overlap it with the computation of the - # shared expert for next microbatch or moe gating - for i in range(num_micro_batchs): - ms_metadata.try_wait_event(layer_index, i, - MSEventKey.ATTN_AR_FINISH) - context = MultiStreamStepMetadata( - comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_SE_COMP_FINISH], - after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_SE_COMM_FINISH], - ) - with set_multistream_context(context, i): - # compute shared expert after finishing ATTN AR - hidden_states[i], residual[ - i] = self._forward_ms_op_post_attn_layernorm( - hidden_states[i], residual[i]) - - num_token, hidden_dim = hidden_states[i].shape - hidden_states[i] = hidden_states[i].view(-1, hidden_dim) - num_tokens.append(num_token) - hidden_dims.append(hidden_dim) - if self.mlp.n_shared_experts is not None: - # TODO: we can move shared expert computation into next block if reduce results is false - shared_output = self.mlp._forward_ms_op_shared_expert( - hidden_states[i]) - shared_outputs.append(shared_output) - - # block 4 : moe - for i in range(num_micro_batchs): - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. 
- if attn_metadata[i] is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = attn_metadata[i].num_prefills > 0 - enable_force_load_balance = False - - if self.mlp.tp_size > 1: - num_token, _ = hidden_states[i].shape - padded_num_tokens = (self.mlp.tp_size - num_tokens[i] % - self.mlp.tp_size) % self.mlp.tp_size - if padded_num_tokens > 0: - hidden_states[i] = nn.functional.pad( - hidden_states[i], (0, 0, 0, padded_num_tokens)) - chunk_hidden_state = torch.tensor_split(hidden_states[i], - self.mlp.tp_size, - dim=0) - chunk_hidden_states.append(chunk_hidden_state) - local_hidden_states = chunk_hidden_state[self.mlp.tp_rank] - else: - local_hidden_states = hidden_states[i] - - router_logit = self.mlp._forward_ms_op_gate(local_hidden_states) - router_logits.append(router_logit) - - if CustomDeepseekDBOMoE.top_k: - real_top_k = CustomDeepseekDBOMoE.top_k - else: - real_top_k = self.mlp.experts.top_k - - hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp( - local_hidden_states, router_logits[i], is_prefill, real_top_k, - enable_force_load_balance) - - # the following kernels will be submitted to the comm stream to overlap the computation of the - # moe computation of next microbatch and the attn computation of next layer - context = MultiStreamStepMetadata( - comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_COM_FINISH], - after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], - ) - context.before_comm_event.record() - with torch.npu.stream(ms_metadata.communicate_stream): - context.before_comm_event.wait() - if self.mlp.experts.reduce_results and ( - self.mlp.experts.tp_size > 1 - or self.mlp.experts.ep_size > 1): - hidden_states[i] = tensor_model_parallel_all_reduce( - hidden_states[i]) - hidden_states[ - i] = hidden_states[i] * self.mlp.routed_scaling_factor - context.after_comm_event.record() - - context 
= MultiStreamStepMetadata( - comm_stream=ms_metadata.communicate_stream, - before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], - after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_AR_FINISH], - ) - with set_multistream_context(context, i): - if self.mlp.tp_size > 1: - hidden_states[i] = self.mlp._forward_ms_op_tp_allgather( - hidden_states[i], chunk_hidden_states[i], - padded_num_tokens) - with torch.npu.stream(ms_metadata.communicate_stream): - # last - if shared_outputs[i] is not None: - hidden_states[i] = hidden_states[i] + shared_outputs[i] - hidden_states[i] = hidden_states[i].view( - num_tokens[i], hidden_dims[i]) - if isinstance(self.mlp, CustomDeepseekDBOMLP - ) and hidden_states[i].dtype == torch.float16: - # Fix FP16 overflow - # Scaling the DeepseekV2MLP output, it is the input of - # input_layernorm of next decoder layer. - # The scaling of DeepseekV2MOE output would be done in the forward - # of DeepseekV2MOE - hidden_states[i] *= 1. 
/ self.routed_scaling_factor - context.after_comm_event.record() - return hidden_states, residual - - # should split ops in Decoder Layer - def _forward_ms_op_input_layernorm( - self, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, torch.Tensor]: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - return hidden_states, residual - - def _forward_ms_op_attn( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - if hidden_states.dtype == torch.float16: - # Fix FP16 overflow - # We scale both hidden_states and residual before - # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor - if self.layer_idx == 0: - # The residual is shared by all layers, we only scale it on - # first layer. - residual *= 1. 
/ self.routed_scaling_factor - return hidden_states, residual - - def _forward_ms_op_post_attn_layernorm( - self, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ): - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - return hidden_states, residual - - -class CustomDeepseekDBOModel(nn.Module): - - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.first_k_dense_replace = config.first_k_dense_replace - - if get_pp_group().is_first_rank: - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") - else: - self.embed_tokens = PPMissingLayer() - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: CustomDeepseekDBODecoderLayer( - config, - prefix, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - ), - prefix=f"{prefix}.layers") - - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - # tbo related members - if VLLM_ASCEND_ENABLE_DBO: - self.use_mla = model_config.use_mla - self.multistream_config = MultiStreamConfig() - multistream_metadata = make_multistream_metadata_ds( - start_layer=self.start_layer + self.first_k_dense_replace, - end_layer=self.end_layer, - causal_lm=getattr(config, "causal_lm", True), - multistream_config=self.multistream_config, - ) - 
self.ms_pre_layer = MultiStreamPreTransformerLayer( - multistream_metadata) - self.ms_post_layer = MultiStreamPostTransformerLayer( - multistream_metadata) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - num_normal_layers = (self.first_k_dense_replace - if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() - else self.end_layer - self.start_layer) - - moe_start_layer = self.start_layer + num_normal_layers - for i in range(self.start_layer, min(moe_start_layer, self.end_layer)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, hidden_states, residual, - kv_caches[i - - self.start_layer] if kv_caches is not None else None, - attn_metadata) - - if moe_start_layer < self.end_layer: - # if we enable multistream/dbo, process sparse layers here - hidden_states, residual = self._forward_ms_layers( - positions=positions, - hidden_states=hidden_states, - residual=residual, - moe_start_layer=moe_start_layer, - kv_caches=kv_caches, - ) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - def can_run_ms(self): - attn_metadata = get_forward_context().attn_metadata - # 
enable prefill overlap - return not (attn_metadata is None or attn_metadata.num_prefills == 0 - or not attn_metadata.enable_dbo_across_dp) - - def _forward_ms_layers( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor, - moe_start_layer: int, - kv_caches: Optional[List[torch.Tensor]] = None, - is_prefill: bool = False, - ): - - if moe_start_layer == self.end_layer: - return hidden_states, residual - - attn_metadata, [positions, hidden_states, - residual] = self.ms_pre_layer( - [positions, hidden_states, residual], ) - # the rest layers - for i in range(moe_start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer._forward_ms_layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - attn_metadata=attn_metadata, - kv_cache=kv_caches[i - self.start_layer] - if kv_caches is not None else None, - is_prefill=is_prefill) - advance_step_multistream_layer_context() - - [hidden_states, - residual] = self.ms_post_layer([hidden_states, residual], ) - return hidden_states, residual - - -class CustomDeepseekDBOForCausalLM(DeepseekV2ForCausalLM): - # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging - packed_modules_mapping = { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = CustomDeepseekDBOModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - else: - self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = 
get_sampler() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - # NOTE: This `load_weights` is mainly copied from - # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5 - # to fix CI, and it is different from the implementation in main - # TODO: support eplb style load_weights - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - """""" - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = AscendFusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) - if spec_layer is not None: - continue # skip spec decode layers for main model - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - # We have mlp.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id, - return_success=False) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - return hidden_states diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py index 8bcc4fb..0c4f173 100644 --- a/vllm_ascend/models/deepseek_mtp.py +++ b/vllm_ascend/models/deepseek_mtp.py @@ -23,22 +23,20 @@ import torch import torch.nn as nn from transformers import PretrainedConfig from vllm.attention.backends.abstract 
import AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.deepseek_mtp import ( DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, SharedHead) +from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer from vllm.model_executor.models.utils import maybe_prefix -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .deepseek_v2 import CustomDeepseekV2DecoderLayer - class CustomDeepSeekShareHead(SharedHead): @@ -65,6 +63,7 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): quant_config: Optional[QuantizationConfig] = None, ) -> None: nn.Module.__init__(self) + vllm_config = get_current_vllm_config() self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -75,10 +74,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): quant_config=quant_config, prefix=maybe_prefix( prefix, "shared_head")) - self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix, - model_config, - cache_config, - quant_config) + self.mtp_block = DeepseekV2DecoderLayer(vllm_config=vllm_config, + prefix=prefix) def forward( self, @@ -103,8 +100,6 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): hidden_states, residual = self.mtp_block(positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - 
attn_metadata=attn_metadata, residual=None) hidden_states = residual + hidden_states return hidden_states @@ -171,7 +166,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, + sampling_metadata=None, # type: ignore spec_step_idx: int = 0, ) -> torch.Tensor: current_step_idx = (spec_step_idx % self.num_mtp_layers) @@ -183,14 +178,6 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): class CustomDeepSeekMTP(DeepSeekMTP): - # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; - # NOTE 2.The description file generated by the current msmodelslim tool does not have - # MTP layer info. Please manually add it and set the value to FLOAT. - packed_modules_mapping = { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] - } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) @@ -199,8 +186,6 @@ class CustomDeepSeekMTP(DeepSeekMTP): prefix=maybe_prefix( prefix, "model")) - self.sampler = get_sampler() - def forward( self, input_ids: torch.Tensor, @@ -215,4 +200,4 @@ class CustomDeepSeekMTP(DeepSeekMTP): hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, previous_hidden_states, inputs_embeds, spec_step_idx) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 6d0913c..2333c38 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -25,160 +25,44 @@ # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py # """Inference-only DeepseekV2/DeepseekV3 model.""" -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Union import torch -import torch_npu from torch import nn 
from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata -from vllm.config import (CacheConfig, ModelConfig, VllmConfig, - get_current_vllm_config) +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, split_tensor_along_last_dim, - tensor_model_parallel_all_reduce, - tensor_model_parallel_reduce_scatter) -from vllm.distributed.parallel_state import get_dp_group, get_ep_group -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.activation import SiluAndMul + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, ReplicatedLinear, - RowParallelLinear, - UnquantizedLinearMethod) + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mla import MultiHeadLatentAttention from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.deepseek_v2 import \ - DeepseekV2ForCausalLM # noqa: E501 from vllm.model_executor.models.deepseek_v2 import \ yarn_get_mscale # noqa: E501 from vllm.model_executor.models.deepseek_v2 import ( - DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention, + DeepseekV2Attention, 
DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, + DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, get_spec_layer_idx_from_weight_name) -from vllm.model_executor.models.utils import ( - PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.sequence import IntermediateTensors +from vllm.model_executor.models.utils import (PPMissingLayer, + is_pp_missing_parameter, + maybe_prefix) from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.models.layers.mla import AscendMLAModules +from vllm_ascend.models.layers.sfa import (AscendSFAModules, + AscendSparseFlashAttention, Indexer) from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.quantization.quant_config import AscendLinearMethod -from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod -from vllm_ascend.utils import dispose_tensor - - -class CustomDeepseekV2SiluAndMul(SiluAndMul): - - def __init__(self, - *, - weight_scale: Optional[Callable[[], torch.Tensor]] = None): - super().__init__() - self.weight_scale = weight_scale - - def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor, - torch.Tensor]]): - if isinstance(x, tuple): - assert self.weight_scale is not None - # For AscendW8A8DynamicLinearMethod: - # a dynamic scale is passed along with the quantized value. 
- quantized_x, dynamic_scale = x - return torch_npu.npu_dequant_swiglu_quant( - x=quantized_x, - weight_scale=self.weight_scale(), - activation_scale=dynamic_scale, - activate_left=True, - quant_mode=1) - else: - return super().forward_oot(x) - - -class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear): - - def __init__( - self, - input_size: int, - output_sizes: list[int], - bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - self.output_sizes = output_sizes - super().__init__(input_size, - sum(output_sizes), - bias=bias, - quant_config=quant_config, - prefix=prefix) - - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, loaded_shard_id: int): - # With no support for GGUF format yet. - assert not getattr(param, "is_gguf_weight", False) - assert not getattr(param, "is_gguf_weight_type", False) - - assert loaded_shard_id < len(self.output_sizes) - shard_offset = sum(self.output_sizes[:loaded_shard_id]) - shard_size = self.output_sizes[loaded_shard_id] - shard = param.data.narrow(param.output_dim, shard_offset, shard_size) - - assert shard.size() == loaded_weight.size(), ( - f"Tried to load weights of size {loaded_weight.size()}" - f"to a parameter shard of id {loaded_shard_id} size {shard.size()}" - ) - shard.copy_(loaded_weight) - - -class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear): - - def forward( - self, - input_, - is_prefill=True, - is_force_scatter=False - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: - if self.input_is_parallel: - input_parallel = input_ - else: - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() - - # Matrix multiply. 
- assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) - bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, - input_parallel, - bias=bias_) - if self.reduce_results and self.tp_size > 1: - num_tokens = output_parallel.shape[0] - if is_force_scatter and num_tokens % self.tp_size: - output_parallel = nn.functional.pad( - output_parallel, (0, 0, 0, -num_tokens % self.tp_size)) - if is_force_scatter or (not is_prefill - and output_parallel.shape[0] % self.tp_size - == 0): - output = tensor_model_parallel_reduce_scatter(output_parallel, - dim=0) - else: - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - - output_bias = self.bias if self.skip_bias_add else None - - if not self.return_bias: - return output - return output, output_bias class CustomDeepseekV2RowParallelLinear(RowParallelLinear): @@ -217,215 +101,6 @@ class CustomDeepseekV2RowParallelLinear(RowParallelLinear): return output, output_bias -class CustomDeepseekV2MLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - force_replicate: bool = False, - prefix: str = "", - ) -> None: - super().__init__() - if not force_replicate: - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") - else: - self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - 
self.down_proj = ReplicatedLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - - quant_method = self.gate_up_proj.quant_method - if isinstance(quant_method, UnquantizedLinearMethod): - self.act_fn = CustomDeepseekV2SiluAndMul() - elif (isinstance(quant_method, AscendLinearMethod) and isinstance( - quant_method.quant_method, AscendW8A8DynamicLinearMethod)): - # TODO(sdmyzlp): Currently preserved as before: - # 1. The only quantization supported for silu is W8A8Dynamic - # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16 - # - # Maybe one can implement a better and more general configuration - # scheme, e.g. by somehow passing around the tweaked `quant_config` - self.act_fn = CustomDeepseekV2SiluAndMul( - # Use lazy binding, for `weight_scale_fp32` is accessible - # only after `process_weights_after_loading`. 
- weight_scale=lambda: self.gate_up_proj.weight_scale_fp32) - # To be consumed by AscendW8A8DynamicLinearMethod.apply() - self.gate_up_proj._ascend_quant_config = { - "output_dtype": torch.int32, - "pertoken_scale": False, - "return_scale": True, - } - self.down_proj._ascend_quant_config = { - "output_dtype": torch.bfloat16, - "pertoken_scale": True, - "return_scale": False, - } - else: - raise NotImplementedError( - f"Quantization with [{type(quant_method)}] is NOT supported") - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class CustomDeepseekV2MoE(nn.Module): - - top_k: int - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.routed_scaling_factor = config.routed_scaling_factor - self.n_shared_experts = config.n_shared_experts - if self.tp_size > config.n_routed_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.n_routed_experts}.") - - if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" - "Only silu is supported for now.") - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_moe = \ - ascend_config.torchair_graph_config.enable_multistream_moe and \ - self.torchair_graph_enabled - - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") - if config.topk_method == "noaux_tc": - self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) - else: - self.gate.e_score_correction_bias = None - - self.experts = AscendFusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias) - - if config.n_shared_experts is not None: - self.all_reduce_merge = self.experts.all_reduce_merge - reduce_results = not self.all_reduce_merge - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) - enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - self.shared_experts = CustomDeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=reduce_results, - force_replicate=self.enable_multistream_moe - or enable_shared_expert_dp, - prefix=f"{prefix}.shared_experts", - ) - else: - self.shared_experts = None # type: ignore - CustomDeepseekV2MoE.top_k = config.num_experts_per_tok - - self.dp_size = get_dp_group().world_size - - self.tp_group = get_tp_group().device_group - self.tp_rank = get_tp_group().rank_in_group - self.ep_group = 
get_ep_group() - self.kv_consumer = None - transfer_config = get_current_vllm_config().kv_transfer_config - if transfer_config is not None: - self.kv_consumer = transfer_config.kv_role == "kv_consumer" - - self.params_dtype = torch.get_default_dtype() - self.rm_router_logits = self.experts.rm_router_logits - - def forward(self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None, - replace_allreduce: bool = False) -> torch.Tensor: - - forward_context = get_forward_context() - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. - - enable_force_load_balance = forward_context.in_profile_run - - is_prefill = forward_context.with_prefill - - # If this node is kv_consumer, we force the moe always runs in decode path to make sure - # the behaviour aligned between dummy_run and normal model_execute. - if self.kv_consumer: - is_prefill = False - enable_force_load_balance = False - - # router_logits: (num_tokens, n_experts) - router_logits = None - if not self.rm_router_logits and not self.enable_multistream_moe: - router_logits, _ = self.gate(hidden_states) - - experts_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=CustomDeepseekV2MoE.top_k, - enable_force_load_balance=enable_force_load_balance, - shared_experts=self.shared_experts, - gate=self.gate, - replace_allreduce=replace_allreduce) - - hidden_states = ( - experts_hidden_states[0] * self.routed_scaling_factor + - experts_hidden_states[1]) - if self.all_reduce_merge: - # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - - return hidden_states - - class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): def __init__( @@ -508,23 +183,12 @@ class 
CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_b_proj") - if (config.n_routed_experts is not None - and self.debug_layer_idx >= config.first_k_dense_replace - and self.debug_layer_idx % config.moe_layer_freq == 0 - and self.enable_shared_expert_dp): - self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - else: - self.o_proj = CustomDeepseekV2RowParallelLinear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = CustomDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") if rope_scaling: rope_scaling["rope_type"] = 'deepseek_yarn' @@ -540,29 +204,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. 
- # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, + mla_modules = AscendMLAModules( q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None, @@ -571,6 +213,28 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): kv_a_layernorm=self.kv_a_layernorm, kv_b_proj=self.kv_b_proj, o_proj=self.o_proj, + rotary_emb=self.rotary_emb, + ) + + self.mla_attn = MultiHeadLatentAttention( + self.hidden_size, + self.enable_shared_expert_dp, + self.debug_layer_idx, + self.first_k_dense_replace, + self.tp_size, + mla_modules, + self.num_local_heads, + self.scaling, + self.layers, + self.kv_lora_rank, + self.qk_rope_head_dim, + self.q_lora_rank, + self.qk_nope_head_dim, + self.qk_head_dim, + self.v_head_dim, + cache_config, + quant_config, + prefix, ) def forward( @@ -579,43 +243,194 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): hidden_states: torch.Tensor, kv_cache: Optional[torch.Tensor] = None, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - forward_context = get_forward_context() - if kv_cache is None: - kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine] - num_tokens = hidden_states.shape[0] - need_gather_q_kv = False - if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: - # Simulate all gather to calculate output shape - num_tokens = num_tokens * self.tp_size 
- need_gather_q_kv = True - if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: - output_shape = hidden_states.shape - else: - rows = num_tokens // self.tp_size - if num_tokens % self.tp_size: - rows += 1 - output_shape = (rows, hidden_states.shape[1]) - output = torch.empty(output_shape, - dtype=hidden_states.dtype, - device=hidden_states.device) - output = self.mla_attn.impl.forward(hidden_states, kv_cache, - forward_context.attn_metadata, - need_gather_q_kv, output) - output = output.view(-1, output_shape[-1]) - return output + return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata) -class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): +class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention): def __init__( self, config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + self.tp_size = get_tensor_model_parallel_world_size() + assert num_heads % self.tp_size == 0 + self.num_local_heads = num_heads // self.tp_size + self.layers = config.num_hidden_layers + self.first_k_dense_replace = config.first_k_dense_replace + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.prefix = prefix + 
self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + return_bias=False, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + return_bias=False, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + return_bias=False, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + return_bias=False, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + return_bias=False, + ) + self.o_proj = CustomDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + return_bias=False, + ) + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, 
float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.dim: int = config.hidden_size # 7168 + # TODO(zzzzwwjj): wait transformers add these params + self.n_heads: int = 64 # 64 + self.head_dim: int = 128 # 128 + self.index_topk: int = 2048 # 2048 + self.indexer = Indexer( + config, + quant_config=quant_config, + dim=self.dim, + n_heads=self.n_heads, + head_dim=self.head_dim, + index_topk=self.index_topk, + prefix=f"{prefix}.indexer", + ) + + sfa_modules = AscendSFAModules( + q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + rotary_emb=self.rotary_emb, + indexer=self.indexer) + + self.sfa_attn = AscendSparseFlashAttention( + self.hidden_size, + self.enable_shared_expert_dp, + self.debug_layer_idx, + self.first_k_dense_replace, + self.tp_size, + sfa_modules, + self.num_local_heads, + self.scaling, + self.layers, + self.kv_lora_rank, + self.qk_rope_head_dim, + self.q_lora_rank, + self.qk_nope_head_dim, + self.qk_head_dim, + self.v_head_dim, + cache_config, + quant_config, + prefix, + ) + self.prefix = prefix + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata) + + +class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): + + def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config 
+ ascend_config = get_ascend_config() + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -628,10 +443,12 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): self.layers = config.num_hidden_layers self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tp_group().rank_in_group - ascend_config = get_ascend_config() # TODO: enable mla in vllm-ascend if model_config.use_mla: - attn_cls = CustomDeepseekV2MLAAttention + if ascend_config.use_sfa: + attn_cls = CustomDeepseekV2SFAAttention + else: + attn_cls = CustomDeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention self.self_attn = attn_cls( @@ -655,13 +472,18 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = CustomDeepseekV2MoE( + self.mlp = DeepseekV2MoE( config=config, + parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.mlp", ) + if self.mlp.gate.e_score_correction_bias is not None: + self.mlp.gate.e_score_correction_bias.data = ( + self.mlp.gate.e_score_correction_bias.data.to( + dtype=torch.get_default_dtype())) else: - self.mlp = CustomDeepseekV2MLP( + self.mlp = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -675,194 +497,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): self.routed_scaling_factor = config.routed_scaling_factor self.first_k_dense_replace = config.first_k_dense_replace self.tp_group = get_tp_group().device_group - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - 
replace_allreduce: bool = False, - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - previous_hidden_states, previous_residual = hidden_states, residual - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - # Dispose hidden_states and residual from the previous layer - # to save npu memory because they're no longer used. - dispose_tensor(previous_hidden_states) - dispose_tensor(previous_residual) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - if hidden_states.dtype == torch.float16: - # Fix FP16 overflow - # We scale both hidden_states and residual before - # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor - if self.layer_idx == 0: - # The residual is shared by all layers, we only scale it on - # first layer. - residual *= 1. 
/ self.routed_scaling_factor - - tp_size = get_tensor_model_parallel_world_size() - if self.enable_shared_expert_dp and ( - self.layer_idx == self.first_k_dense_replace - or self.layer_idx == self.layers) and tp_size > 1: - num_tokens, _ = residual.shape - if num_tokens % tp_size: - residual = nn.functional.pad(residual, - (0, 0, 0, -num_tokens % tp_size)) - chunk_residual = torch.tensor_split(residual, tp_size, dim=0) - tp_rank = get_tensor_model_parallel_rank() - residual = chunk_residual[tp_rank] - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - if isinstance(self.mlp, CustomDeepseekV2MoE): - hidden_states = self.mlp(hidden_states, attn_metadata) - else: - hidden_states = self.mlp(hidden_states) - - if isinstance( - self.mlp, - CustomDeepseekV2MLP) and hidden_states.dtype == torch.float16: - # Fix FP16 overflow - # Scaling the DeepseekV2MLP output, it is the input of - # input_layernorm of next decoder layer. - # The scaling of DeepseekV2MOE output would be done in the forward - # of DeepseekV2MOE - hidden_states *= 1. / self.routed_scaling_factor - - # for last layer of main model and mtp layer. 
- if self.enable_shared_expert_dp and self.layer_idx >= ( - self.layers - 1) and tp_size > 1: - hidden_states = get_tp_group().all_gather(hidden_states, 0) - residual = get_tp_group().all_gather(residual, 0) - - attn_metadata = get_forward_context().attn_metadata - if attn_metadata is not None: - num_tokens = attn_metadata.num_actual_tokens - else: - num_tokens = hidden_states.shape[0] - - if num_tokens < hidden_states.shape[0]: - hidden_states = hidden_states[:num_tokens] - residual = residual[:num_tokens] - - return hidden_states, residual - - -class CustomDeepseekV2Model(nn.Module): - - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.tp_size = get_tensor_model_parallel_world_size() - - if get_pp_group().is_first_rank: - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") - else: - self.embed_tokens = PPMissingLayer() - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: CustomDeepseekV2DecoderLayer( - config, - prefix, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - ), - prefix=f"{prefix}.layers") - - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: 
torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - replace_allreduce = hidden_states.shape[0] % self.tp_size == 0 - - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - residual, - kv_caches[i - - self.start_layer] if kv_caches is not None else None, - attn_metadata, - replace_allreduce=replace_allreduce) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): - # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging - packed_modules_mapping = { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] - } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) @@ -870,9 +507,21 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = CustomDeepseekV2Model(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + + # `packed_modules_mapping` needs to be modified before + # initializing DeepseekV2Model, as it is passed inplace to + # 
quantization config init and may be used to select the + # quant_method for relevant layers during initialization. + self.fuse_qkv_a_proj = hasattr( + config, "q_lora_rank") and config.q_lora_rank is not None + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + + self.model = DeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, @@ -882,9 +531,36 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.expert_weights: list[Any] = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + # Pick last one layer since the first ones may be dense layers. 
+ example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") + + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts # NOTE: This `load_weights` is mainly copied from # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5 @@ -982,16 +658,9 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): loaded_params.add(name) return loaded_params - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - return hidden_states + +class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): + pass + + +DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__ diff --git a/vllm_ascend/models/deepseek_v3.py b/vllm_ascend/models/deepseek_v3.py index 4d09ef0..e69de29 100644 --- a/vllm_ascend/models/deepseek_v3.py +++ b/vllm_ascend/models/deepseek_v3.py @@ -1,27 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. 
It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2ForCausalLM - - -class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): - pass diff --git a/vllm_ascend/models/layers/__init__.py b/vllm_ascend/models/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py new file mode 100644 index 0000000..57c91bd --- /dev/null +++ b/vllm_ascend/models/layers/mla.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.mla import MultiHeadLatentAttention +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.utils import direct_register_custom_op + + +@dataclass +class AscendMLAModules: + q_a_proj: Optional[torch.nn.Module] + q_a_layernorm: Optional[torch.nn.Module] + q_proj: Optional[torch.nn.Module] + kv_a_proj_with_mqa: torch.nn.Module + kv_a_layernorm: torch.nn.Module + kv_b_proj: torch.nn.Module + o_proj: torch.nn.Module + rotary_emb: torch.nn.Module + + +class AscendMultiHeadLatentAttention(MultiHeadLatentAttention): + + def __init__( + self, + hidden_size: int, + enable_shared_expert_dp: bool, + debug_layer_idx: int, + first_k_dense_replace: int, + tp_size: int, + mla_modules: AscendMLAModules, + num_local_heads: int, + scaling: float, + layers: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + q_lora_rank: Optional[int], + qk_nope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.enable_shared_expert_dp = enable_shared_expert_dp + self.debug_layer_idx = debug_layer_idx + self.first_k_dense_replace = 
first_k_dense_replace + self.tp_size = tp_size + self.num_local_heads = num_local_heads + self.layers = layers + self.kv_lora_rank = kv_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.q_lora_rank = q_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + self.prefix = prefix + + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=mla_modules.rotary_emb, + q_a_proj=mla_modules.q_a_proj, + q_a_layernorm=mla_modules.q_a_layernorm, + q_proj=mla_modules.q_proj, + kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, + kv_a_layernorm=mla_modules.kv_a_layernorm, + kv_b_proj=mla_modules.kv_b_proj, + o_proj=mla_modules.o_proj, + ) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + num_tokens = hidden_states.shape[0] + need_gather_q_kv = False + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + # Simulate all gather to calculate output shape + num_tokens = num_tokens * self.tp_size + need_gather_q_kv = True + if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: + output_shape = 
hidden_states.shape + else: + rows = num_tokens // self.tp_size + if num_tokens % self.tp_size: + rows += 1 + output_shape = (rows, hidden_states.shape[1]) + # FIXME: This does not seem right, should make sure the buffer is fixed + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, + self.prefix) + output = output.view(-1, output_shape[-1]) + return output + + +def mla_forward( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + if forward_context.attn_metadata: + attn_metadata = forward_context.attn_metadata[self.mla_attn.layer_name] + else: + attn_metadata = forward_context.attn_metadata + kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine] + self.mla_attn.impl.forward(self.mla_attn.layer_name, hidden_states, + kv_cache, attn_metadata, need_gather_q_kv, + output) + return + + +def mla_forward_fake( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="mla_forward", + op_func=mla_forward, + mutates_args=["output"], + fake_impl=mla_forward_fake, + dispatch_key="PrivateUse1", +) diff --git a/vllm_ascend/models/layers/sfa.py b/vllm_ascend/models/layers/sfa.py new file mode 100644 index 0000000..f68281c --- /dev/null +++ b/vllm_ascend/models/layers/sfa.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.mla import MultiHeadLatentAttention +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.utils import direct_register_custom_op + + +@dataclass +class AscendSFAModules: + q_a_proj: Optional[torch.nn.Module] + q_a_layernorm: Optional[torch.nn.Module] + q_proj: Optional[torch.nn.Module] + kv_a_proj_with_mqa: torch.nn.Module + kv_a_layernorm: torch.nn.Module + kv_b_proj: torch.nn.Module + o_proj: torch.nn.Module + rotary_emb: torch.nn.Module + indexer: torch.nn.Module + + +class AscendSparseFlashAttention(MultiHeadLatentAttention): + + def __init__( + self, + hidden_size: int, + enable_shared_expert_dp: bool, + debug_layer_idx: int, + first_k_dense_replace: int, + tp_size: int, + sfa_modules: AscendSFAModules, + num_local_heads: int, + scaling: float, + layers: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + q_lora_rank: Optional[int], + 
qk_nope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.enable_shared_expert_dp = enable_shared_expert_dp + self.debug_layer_idx = debug_layer_idx + self.first_k_dense_replace = first_k_dense_replace + self.tp_size = tp_size + self.num_local_heads = num_local_heads + self.layers = layers + self.kv_lora_rank = kv_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.q_lora_rank = q_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + self.prefix = prefix + + self.sfa_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + use_sfa=True, + # SFA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=sfa_modules.rotary_emb, + q_a_proj=sfa_modules.q_a_proj, + q_a_layernorm=sfa_modules.q_a_layernorm, + q_proj=sfa_modules.q_proj, + kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa, + kv_a_layernorm=sfa_modules.kv_a_layernorm, + kv_b_proj=sfa_modules.kv_b_proj, + o_proj=sfa_modules.o_proj, + indexer=sfa_modules.indexer) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> 
torch.Tensor: + num_tokens = hidden_states.shape[0] + need_gather_q_kv = False + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + # Simulate all gather to calculate output shape + num_tokens = num_tokens * self.tp_size + need_gather_q_kv = True + if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: + output_shape = hidden_states.shape + else: + rows = num_tokens // self.tp_size + if num_tokens % self.tp_size: + rows += 1 + output_shape = (rows, hidden_states.shape[1]) + # FIXME: This does not seem right, should make sure the buffer is fixed + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output, + self.prefix) + output = output.view(-1, output_shape[-1]) + return output + + +def sfa_forward( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + if forward_context.attn_metadata: + attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name] + else: + attn_metadata = forward_context.attn_metadata + kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine] + self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata, + need_gather_q_kv, output) + return + + +class Indexer(nn.Module): + + def __init__(self, + config, + dim: int = 7168, + n_heads: int = 64, + head_dim: int = 128, + index_topk: int = 2048, + q_lora_rank: int = 1536, + rope_head_dim: int = 64, + quant_config: Optional[QuantizationConfig] = None, + prefix: Optional[str] = ""): + super().__init__() + + self.dim: int = dim # 7168 + self.n_heads: int = n_heads # 64 + self.head_dim: int = head_dim # 128 + self.rope_head_dim: int = rope_head_dim # 64 + self.index_topk: int = index_topk # 2048 
+ self.q_lora_rank: int = q_lora_rank # 1536 + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + return_bias=False, + ) + self.wk = ReplicatedLinear( + self.dim, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk", + return_bias=False, + ) + self.weights_proj = ReplicatedLinear( + self.dim, + self.n_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.weights_proj", + return_bias=False, + ) + self.k_norm = nn.LayerNorm(self.head_dim) + self.softmax_scale = self.head_dim**-0.5 + + def forward(self): + return + + +def sfa_forward_fake( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="sfa_forward", + op_func=sfa_forward, + mutates_args=["output"], + fake_impl=sfa_forward_fake, + dispatch_key="PrivateUse1", +) diff --git a/vllm_ascend/models/pangu_moe.py b/vllm_ascend/models/pangu_moe.py deleted file mode 100644 index 3e2148c..0000000 --- a/vllm_ascend/models/pangu_moe.py +++ /dev/null @@ -1,1106 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch_npu -from torch import nn -from torch.nn import Parameter -from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (divide, get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, - get_tp_group, get_world_group) -from vllm.forward_context import get_forward_context -from vllm.logger import logger -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsPP -from vllm.model_executor.models.utils import ( - extract_layer_index, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors - -from vllm_ascend.utils 
import ACL_FORMAT_FRACTAL_NZ, is_310p - -_ROUTER_SCALE = None - - -def use_h2p(): - # only use H2P when dp_size > 1. - if get_dp_group().world_size > 1: - return True - return False - - -# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear. -# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp). -class CustomMergedColumnParallelLinear(LinearBase): - - def __init__( - self, - input_size: int, - output_sizes: list[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - *, - return_bias: bool = True, - ): - # Divide the weight matrix along the last dimension. - output_size = sum(output_sizes) - self.output_sizes = output_sizes - self.tp_size = get_tp_group().world_size - self.input_size_per_partition = input_size - self.output_size_per_partition = divide(output_size, self.tp_size) - self.output_partition_sizes = [self.output_size_per_partition] - # If QKV or MergedColumn, use output size of each partition. 
- if hasattr(self, "output_sizes"): - self.output_partition_sizes = [ - divide(output_size, self.tp_size) - for output_size in self.output_sizes - ] - - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias) - - self.gather_output = gather_output - - if output_sizes is None: - output_sizes = [output_size] - - assert self.quant_method is not None - self.quant_method.create_weights( - layer=self, - input_size_per_partition=self.input_size_per_partition, - output_partition_sizes=self.output_partition_sizes, - input_size=self.input_size, - output_size=self.output_size, - params_dtype=self.params_dtype, - weight_loader=self.weight_loader) - if bias: - self.bias = Parameter( - torch.empty(self.output_size_per_partition, - dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) - else: - self.register_parameter("bias", None) - - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: int): - param_data = param.data - output_dim = getattr(param, "output_dim", None) - - assert loaded_shard_id < len(self.output_sizes) - - tp_rank = get_tp_group().rank_in_group - tp_size = get_tp_group().world_size - if output_dim is not None: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size - - is_sharded_weight = getattr(param, "is_sharded_weight", False) - param_data = param_data.narrow(output_dim, shard_offset, - shard_size) - start_idx = tp_rank * shard_size - if not is_sharded_weight: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) - else: - ignore_warning = getattr(param, "ignore_warning", False) - if not ignore_warning: - logger.warning( - "Loading a weight without `output_dim` attribute in " - "MergedColumnParallelLinear, assume the weight is " - "the same for all partitions.") - - assert param_data.shape == 
loaded_weight.shape - param_data.copy_(loaded_weight) - - def forward( - self, input_ - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - bias = self.bias if not self.skip_bias_add else None - - # Matrix multiply. - assert self.quant_method is not None - output_parallel = self.quant_method.apply(self, input_, bias) - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - if not self.return_bias: - return output - return output, output_bias - - -# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear. -# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp) -# and detach communication to enable customized communication algorithms(e.g., H2P). -class CustomRowParallelLinear(LinearBase): - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - *, - return_bias: bool = True, - group=None, - ): - # Divide the weight matrix along the first dimension. 
- self.group = group if group is not None else get_tp_group() - self.tp_rank = self.group.rank_in_group - self.tp_size = self.group.world_size - self.input_size_per_partition = divide(input_size, self.tp_size) - self.output_size_per_partition = output_size - self.output_partition_sizes = [output_size] - - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias) - - self.input_is_parallel = input_is_parallel - self.reduce_results = reduce_results - - assert self.quant_method is not None - self.quant_method.create_weights( - layer=self, - input_size_per_partition=self.input_size_per_partition, - output_partition_sizes=self.output_partition_sizes, - input_size=self.input_size, - output_size=self.output_size, - params_dtype=self.params_dtype, - weight_loader=self.weight_loader) - if not reduce_results and (bias and not skip_bias_add): - raise ValueError("When not reduce the results, adding bias to the " - "results can lead to incorrect results") - - if bias: - self.bias = Parameter( - torch.empty(self.output_size, dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) - else: - self.register_parameter("bias", None) - - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = self.group.rank_in_group - input_dim = getattr(param, "input_dim", None) - is_sharded_weight = getattr(param, "is_sharded_weight", False) - is_sharded_weight = is_sharded_weight - - param_data = param.data - if input_dim is not None and not is_sharded_weight: - shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(input_dim, start_idx, - shard_size) - - # Special case for loading scales off disk, which often do not - # have a shape (such as in the case of AutoFP8). 
- if len(loaded_weight.shape) == 0: - loaded_weight = loaded_weight.reshape(1) - - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - - def forward( - self, input_ - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - input_parallel = input_ - - # Matrix multiply. - assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) - bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output = self.quant_method.apply(self, input_parallel, bias=bias_) - - output_bias = self.bias if self.skip_bias_add else None - - if not self.return_bias: - return output - return output, output_bias - - -class PanguProMoEMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = "", - ) -> None: - super().__init__() - if not use_h2p(): - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, - [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj", - ) - else: - self.gate_up_proj = CustomMergedColumnParallelLinear( - hidden_size, - [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = CustomRowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj", - ) - - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. 
" - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -def topk_wrapper(num_voted_experts): - - def pangu_group8_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool = False, - num_expert_group: int = 0, - topk_group: int = 0, - global_num_experts: int = 0, - ): - scores = F.softmax(gating_output, dim=1) - num_tokens = scores.shape[0] - router_scale = _ROUTER_SCALE.squeeze( # type: ignore - ) - # TODO: support disable expert parallel - ep_size = get_ep_group().world_size - local_num_experts = global_num_experts // ep_size - local_num_group = topk // ep_size - experts_per_group = global_num_experts // topk - local_group_start = get_ep_group().rank_in_group * local_num_experts - local_group_end = (get_ep_group().rank_in_group + - 1) * local_num_experts - scores = F.softmax(gating_output, dim=1) - scores = scores[..., local_group_start:local_group_end] - - router_weights = router_scale[local_group_start:local_group_end] - - if num_voted_experts == 8: - # use original topk - topk_weights, topk_ids = torch.max(scores.view( - scores.shape[0], local_num_group, -1), - dim=-1) - bias = torch.arange(0, - local_num_experts, - experts_per_group, - device=scores.device, - dtype=torch.int32).unsqueeze(0) - topk_ids = topk_ids.to(torch.int32) + bias - - else: - group_expert_indices = torch.arange(experts_per_group, - dtype=torch.int32, - device=scores.device).view( - 1, 1, -1) - group_expert_offset = (torch.arange( - local_num_group, dtype=torch.int32, device=scores.device) * - experts_per_group).unsqueeze(0) - expert_index_range = torch.arange(experts_per_group, - dtype=torch.int32, - device=scores.device) - - scores_grouped = scores.view(num_tokens, local_num_group, - experts_per_group) - best_expert_idx = torch.argmax(scores_grouped, - dim=2) # (num_tokens, num_groups) - vote_mask = 
(best_expert_idx.unsqueeze(-1).to( - torch.int32) == group_expert_indices) - - expert_vote_freq = vote_mask.sum(dim=0) - - sorted_indices = torch.argsort(expert_vote_freq, - dim=1, - descending=True).to(torch.int32) - topk_experts = sorted_indices[:, :num_voted_experts] - keep_mask = (( - topk_experts.unsqueeze(-1) == expert_index_range).any( - dim=1)).unsqueeze(0) - - masked_scores = torch.where(keep_mask, scores_grouped, 0) - - topk_weights, best_pos_in_group = masked_scores.max(dim=2) - best_pos_in_group = best_pos_in_group.to(torch.int32) - topk_ids = (best_pos_in_group + group_expert_offset).to( - torch.int32) - - flatten_topk_ids = topk_ids.view(-1) - router_weights = router_weights.index_select(0, flatten_topk_ids).view( - topk_ids.shape) - topk_weights *= router_weights - - return topk_weights, topk_ids - - return pangu_group8_topk - - -class PanguProMoESparseMoeBlock(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_experts = config.num_experts - - if self.tp_size > config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - - self.num_experts_per_tok = config.num_experts_per_tok - self.router_scale = torch.nn.Parameter( - torch.ones((1, self.num_experts))) - - # on 300I Duo platform, we find that num_voted_experts set to 5 achieves - # good performance without sacrifice too much accuracy. for other platform, - # this is set to 8 to use original pangu grouped topk. 
- num_voted_experts = 5 if is_310p() else 8 - - self.experts = FusedMoE( - num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - quant_config=quant_config, - custom_routing_function=topk_wrapper(num_voted_experts), - prefix=f"{prefix}.experts", - ) - self.use_ep = self.experts.use_ep - - self.gate = ReplicatedLinear( - config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate", - ) - - if config.shared_expert_intermediate_size > 0: - self.shared_expert = PanguProMoEMLP( - hidden_size=config.hidden_size, - intermediate_size=config.shared_expert_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - prefix=f"{prefix}.shared_expert", - ) - else: - self.shared_expert = None # type: ignore - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - # NOTE: hidden_states can have either 1D or 2D shape. - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - shared_output = None - if self.shared_expert is not None: - shared_output = self.shared_expert(hidden_states) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - global _ROUTER_SCALE - _ROUTER_SCALE = self.router_scale - - # TODO(angazenn): Does not support MC2 currently - get_forward_context().moe_comm_method_name = "allgathercommimpl" - - if not use_h2p(): - final_hidden_states = self.experts.forward_impl( - hidden_states=hidden_states, router_logits=router_logits) - else: - # TODO: when using h2p, we have to skip communication in vLLM - # native FusedMoE. here we need to design a better FusedMoE - # (maybe using AscendFusedMoE) to enable these different - # communication schema. 
- final_hidden_states = self.experts.quant_method.apply( - layer=self.experts, - x=hidden_states, - router_logits=router_logits, - top_k=self.experts.top_k, - renormalize=False, - use_grouped_topk=False, - global_num_experts=self.experts.global_num_experts, - expert_map=self.experts.expert_map, - custom_routing_function=self.experts.custom_routing_function, - apply_router_weight_on_input=self.experts. - apply_router_weight_on_input) - - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output - if not use_h2p(): - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_dim) - - -class PanguProMoEAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - if use_h2p(): - self.o_proj = CustomRowParallelLinear(self.total_num_heads * - self.head_dim, - hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - group=get_tp_group()) - else: - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - - output, _ = self.o_proj(attn_output) - return output - - -class PanguProMoEDecoderLayer(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str 
= "", - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - - self.self_attn = PanguProMoEAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - - # `mlp_only_layers` in the config. - layer_idx = extract_layer_index(prefix) - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) - if (layer_idx not in mlp_only_layers) and (config.num_experts > 0): - self.mlp = PanguProMoESparseMoeBlock( - config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = PanguProMoEMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - h2p_unpad_idx: Optional[torch.Tensor] = None, - h2p_pad_idx: Optional[torch.Tensor] = None, - is_start_layer: Optional[bool] = False, - ) -> torch.Tensor: - need_h2p_pad = h2p_unpad_idx is not None and h2p_pad_idx is not None \ - and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0] - tp_size = get_tp_group().world_size - - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = 
self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - if use_h2p(): - if is_start_layer: - if need_h2p_pad: - residual = residual.index_select(dim=0, index=h2p_pad_idx) - residual = torch.tensor_split( - residual, tp_size)[get_tp_group().rank_in_group] - else: - if tp_size > 1: - hidden_states = get_tp_group().all_gather(hidden_states, 0) - if need_h2p_pad: - hidden_states = hidden_states.index_select( - dim=0, index=h2p_unpad_idx) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - if use_h2p(): - if need_h2p_pad: - hidden_states = hidden_states.index_select(dim=0, - index=h2p_pad_idx) - if tp_size > 1: - hidden_states = dist._functional_collectives.reduce_scatter_tensor( - hidden_states, - "sum", - scatter_dim=0, - group=get_tp_group().device_group) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - if use_h2p(): - all_rank_group = get_world_group().device_group - output_size = (hidden_states.shape[0] * - get_world_group().world_size, - hidden_states.shape[1]) - # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=hidden_states.dtype, - device=hidden_states.device) - # All-gather. 
- dist.all_gather_into_tensor(output_tensor, - hidden_states, - group=all_rank_group) - hidden_states = output_tensor - - hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata) - - if use_h2p(): - hidden_states = dist._functional_collectives.reduce_scatter_tensor( - hidden_states, - "sum", - scatter_dim=0, - group=get_world_group().device_group) - - return hidden_states, residual - - -@support_torch_compile -class PanguProMoEModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: PanguProMoEDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers", - ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = 
self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - if use_h2p(): - # calculate necessary padding/unpadding idx before model forward. - - # the attn_metadata will be passed directly when use torchair. - # if attn_meatadata is not passed, we try to get it from forward_context. - if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata - - max_tokens_across_dp = get_forward_context().max_tokens_across_dp - - tp_size = get_tp_group().world_size - # reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks. - # we need pad it before if the shape can't be divided by group size. - # for h2p, we need pad it so that it can be divided by tp_size. - h2p_padded_len = ( - tp_size - (max_tokens_across_dp % tp_size) - ) % tp_size + max_tokens_across_dp - hidden_states.shape[0] - h2p_unpad_idx = torch.arange(hidden_states.shape[0], - device=hidden_states.device, - dtype=torch.int32) - h2p_pad_idx = torch.cat([ - h2p_unpad_idx, - torch.zeros(h2p_padded_len, - dtype=torch.int32, - device=hidden_states.device) - ]) - else: - h2p_unpad_idx = None - h2p_pad_idx = None - - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, hidden_states, residual, - kv_caches[i - - self.start_layer] if kv_caches is not None else None, - attn_metadata, h2p_unpad_idx, h2p_pad_idx, - i == self.start_layer) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - if use_h2p(): - if get_tp_group().world_size > 1: - hidden_states = get_tp_group().all_gather(hidden_states, 0) - if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]: - hidden_states = hidden_states.index_select(dim=0, - 
index=h2p_unpad_idx) - return hidden_states - - -class PanguProMoEForCausalLM(nn.Module, SupportsPP): - - fall_back_to_pt_during_load = False - - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = PanguProMoEModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.lm_head", - ) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: Optional[torch.Tensor], - 
sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - tp_size = get_tp_group().world_size - tp_rank = get_tp_group().rank_in_group - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) - - # expert_params_mapping = [] - - params_dict = dict(self.named_parameters()) # from model - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - # ======================================================= - # BF: add this to load with less layers - if 'layers' in name: - layer_idx = int(name.split('layers.')[-1].split('.')[0]) - if layer_idx >= self.model.end_layer: - continue - - if "rotary_emb.inv_freq" in name: - continue - - if "module" in name: - continue - - if name.endswith('kv_cache_offset'): - continue - - if name.endswith("k_proj.kv_cache_scale"): - remapped_kv_scale_name = name.replace( - "k_proj.kv_cache_scale", "attn.key_antiquant_scale") - if remapped_kv_scale_name not in params_dict: - logger.warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). 
" - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name - param = params_dict[name] - loaded_weight = torch.tensor_split(loaded_weight, - tp_size, - dim=0)[tp_rank] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - if name.endswith("v_proj.kv_cache_scale"): - remapped_kv_scale_name = name.replace( - "v_proj.kv_cache_scale", "attn.value_antiquant_scale") - if remapped_kv_scale_name not in params_dict: - logger.warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). " - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name - param = params_dict[name] - loaded_weight = torch.tensor_split(loaded_weight, - tp_size, - dim=0)[tp_rank] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - # We have mlp.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if "mlp.experts" in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - - # Skip layers on other devices. 
- if is_pp_missing_parameter(name, self): - continue - if name not in params_dict: - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - # breakpoint() - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = param.weight_loader - # breakpoint() - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - logger.warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). " - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - if is_310p() and "head" in name: - # on 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much more preferred than - # ACL_FORMAT_FRACTAL_ND by matmul operation. Since lmhead is also implemented - # by linear, we manually cast the format here. 
- param.data = torch_npu.npu_format_cast(param.data, - ACL_FORMAT_FRACTAL_NZ) - return loaded_params diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py index 31ad260..f240fd1 100644 --- a/vllm_ascend/models/qwen2_5_vl.py +++ b/vllm_ascend/models/qwen2_5_vl.py @@ -42,6 +42,8 @@ from vllm.model_executor.models.qwen2_5_vl import ( from vllm.model_executor.models.utils import maybe_prefix from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm_ascend.utils import vllm_version_is + MIN_PAD_SIZE = 64 # min_size to pad weight MAX_PAD_SIZE = 128 # max_size to pad weight @@ -291,6 +293,40 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer): self.hidden_size, -1) return out_weight + def pad_qkv_weight_scale_offset(self, data): + reshaped_data = data.reshape( + -1, 3, self.origin_hidden_size_per_attention_head, 1) + data1 = reshaped_data[:, :, :self. + half_origin_hidden_size_per_attention_head, :] + data2 = reshaped_data[:, :, self. + half_origin_hidden_size_per_attention_head:, :] + data1_paded = torch.nn.functional.pad( + data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, + 0, 0, 0)) + data2_paded = torch.nn.functional.pad( + data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, + 0, 0, 0)) + res = torch.cat([data1_paded, data2_paded], dim=2) + res = res.reshape(-1, 1) + return res + + def pad_qkv_deq_scale_quant_bias(self, data): + reshaped_data = data.reshape( + -1, 3, self.origin_hidden_size_per_attention_head) + data1 = reshaped_data[:, :, :self. 
+ half_origin_hidden_size_per_attention_head] + data2 = reshaped_data[:, :, + self.half_origin_hidden_size_per_attention_head:] + + data1_paded = torch.nn.functional.pad( + data1, (0, self.half_pad_hidden_size_per_attention_head)) + data2_paded = torch.nn.functional.pad( + data2, (0, self.half_pad_hidden_size_per_attention_head)) + + res = torch.cat([data1_paded, data2_paded], dim=2) + res = res.reshape(-1) + return res + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [ @@ -318,11 +354,23 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if ("attn.proj.weight" in name) and self.enable_pad: + if ("attn.proj.weight_scale" in name or + "attn.proj.weight_offset" in name) and self.enable_pad: + continue + elif ("attn.proj.deq_scale" in name + or "attn.proj.quant_bias" in name) and self.enable_pad: + continue + elif ("attn.qkv.weight_scale" in name + or "attn.qkv.weight_offset" in name) and self.enable_pad: + param.data = self.pad_qkv_weight_scale_offset(param.data) + elif ("attn.qkv.deq_scale" in name + or "attn.qkv.quant_bias" in name) and self.enable_pad: + param.data = self.pad_qkv_deq_scale_quant_bias(param.data) + elif ("attn.proj.weight" in name) and self.enable_pad: param.data = self.pad_proj_weight(param.data) - if ("attn.qkv.weight" in name) and self.enable_pad: + elif ("attn.qkv.weight" in name) and self.enable_pad: param.data = self.pad_qkv_weight(param.data) - if ("attn.qkv.bias" in name) and self.enable_pad: + elif ("attn.qkv.bias" in name) and self.enable_pad: param.data = self.pad_qkv_bias(param.data) loaded_params.add(name) return loaded_params @@ -450,12 +498,20 @@ class AscendQwen2_5_VLForConditionalGeneration( super().__init__(vllm_config=vllm_config, prefix=prefix) config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config 
quant_config = vllm_config.quant_config - self.visual = AscendQwen2_5_VisionTransformer( - vision_config=config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=self._maybe_ignore_quant_config(quant_config), - prefix=maybe_prefix(prefix, "visual"), - ) + if vllm_version_is("0.10.2"): + self.visual = AscendQwen2_5_VisionTransformer( + vision_config=config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + else: + self.visual = AscendQwen2_5_VisionTransformer( + vision_config=config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]: diff --git a/vllm_ascend/models/qwen2_5_vl_without_padding.py b/vllm_ascend/models/qwen2_5_vl_without_padding.py index 5a243e0..f62009b 100644 --- a/vllm_ascend/models/qwen2_5_vl_without_padding.py +++ b/vllm_ascend/models/qwen2_5_vl_without_padding.py @@ -1,6 +1,5 @@ # # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Adapted from vllm/model_executor/models/qwen2_5_vl.py # Copyright 2023 The vLLM team. # # This file is a part of the vllm-ascend project. 
@@ -27,10 +26,19 @@ import torch_npu from einops import rearrange from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) + +try: + from transformers.models.qwen3_vl.configuration_qwen3_vl import \ + Qwen3VLConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \ + Qwen3VLMoeConfig +except ImportError: + pass from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils -from vllm.model_executor.layers.activation import get_act_and_mul_fn +from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, + get_act_and_mul_fn) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.qwen2_5_vl import ( @@ -38,10 +46,29 @@ from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo) -from vllm.model_executor.models.utils import maybe_prefix + +try: + from vllm.model_executor.models.qwen3_vl import ( + Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer, + Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration, + Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) + from vllm.model_executor.models.qwen3_vl_moe import ( + Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo) +except ImportError: + Qwen3_VisionBlock = object + Qwen3_VisionPatchEmbed = object + Qwen3_VisionTransformer = object + Qwen3VLDummyInputsBuilder = object + Qwen3VLForConditionalGeneration = object + Qwen3VLMultiModalProcessor = object + Qwen3VLProcessingInfo = object + Qwen3VLMoeForConditionalGeneration = object + Qwen3VLMoeProcessingInfo = object +from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix from vllm.multimodal import MULTIMODAL_REGISTRY from 
vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding +from vllm_ascend.utils import vllm_version_is class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention): @@ -112,16 +139,14 @@ class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention): class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock): - def __init__( - self, - dim: int, - num_heads: int, - mlp_hidden_dim: int, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, quant_config, prefix) self.attn = AscendQwen2_5_VisionAttention_Without_Padding( @@ -321,6 +346,133 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer return x +class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.matmul( + self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) + x = x + self.proj.bias + return x + + +class AscendQwen3_VisionBlock(Qwen3_VisionBlock): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, + quant_config, prefix, use_data_parallel) + self.attn = AscendQwen2_5_VisionAttention_Without_Padding( + embed_dim=dim, + 
num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) + + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__(vision_config, norm_eps, quant_config, prefix, + use_data_parallel) + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + self.patch_embed = AscendQwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + self.blocks = nn.ModuleList([ + AscendQwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(vision_config.depth) + ]) + self.hidden_size_per_attention_head = dist_utils.divide( + self.hidden_size, self.num_heads) + + def cal_cos_sin(self, rotary_pos_emb): + cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] + sin = rotary_pos_emb.sin() + cos_new = torch.cat((cos, cos), dim=-1) + sin_new = torch.cat((sin, sin), dim=-1) + cos_new = cos_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + sin_new = sin_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + return cos_new, sin_new + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = 
self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + grid_thw_tensor = torch.tensor(grid_thw, + device=self.device, + dtype=torch.int32) + cu_seqlens = torch.repeat_interleave( + grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], + grid_thw_tensor[:, 0]).cpu().to(torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + hidden_states = hidden_states.unsqueeze(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + + cos, sin = self.cal_cos_sin(rotary_pos_emb) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens, + cos=cos, + sin=sin) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index( + layer_num) + deepstack_feature = self.deepstack_merger_list[ + deepstack_merger_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat( + [hidden_states] + deepstack_feature_lists, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + + @MULTIMODAL_REGISTRY.register_processor( Qwen2_5_VLMultiModalProcessor, info=Qwen2_5_VLProcessingInfo, @@ -332,12 +484,20 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding( super().__init__(vllm_config=vllm_config, prefix=prefix) config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - self.visual = AscendQwen2_5_VisionTransformer_Without_Padding( - vision_config=config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=self._maybe_ignore_quant_config(quant_config), - prefix=maybe_prefix(prefix, "visual"), - ) + if vllm_version_is("0.10.2"): + self.visual = AscendQwen2_5_VisionTransformer_Without_Padding( + vision_config=config.vision_config, + 
norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + else: + self.visual = AscendQwen2_5_VisionTransformer_Without_Padding( + vision_config=config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]: @@ -371,3 +531,101 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding( merge_size = self.visual.spatial_merge_size sizes = grid_thw.prod(-1) // merge_size // merge_size return video_embeds.split(sizes.tolist()) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, + info=Qwen3VLProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder) +class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + supports_encoder_tp_data = True + + # To ensure correct weight loading and mapping. 
+ hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config: Qwen3VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + if vllm_version_is("0.10.2"): + self.visual = AscendQwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel) + else: + self.visual = AscendQwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, + info=Qwen3VLMoeProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder) +class AscendQwen3VLMoeForConditionalGeneration( + Qwen3VLMoeForConditionalGeneration): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + supports_encoder_tp_data = True + + # To ensure correct weight loading and mapping. 
+ hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + if vllm_version_is("0.10.2"): + self.visual = AscendQwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) + else: + self.visual = AscendQwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py index a677b06..9648e07 100644 --- a/vllm_ascend/models/qwen2_vl.py +++ b/vllm_ascend/models/qwen2_vl.py @@ -40,6 +40,8 @@ from vllm.model_executor.models.qwen2_vl import ( from vllm.model_executor.models.utils import maybe_prefix from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm_ascend.utils import vllm_version_is + MIN_PAD_SIZE = 64 # min_size to pad weight MAX_PAD_SIZE = 128 # max_size to pad weight @@ -343,10 +345,18 @@ class AscendQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) - self.visual = AscendQwen2VisionTransformer( - self.config.vision_config, - norm_eps=getattr(self.config, "rms_norm_eps", 
1e-6), - quant_config=self._maybe_ignore_quant_config( - vllm_config.quant_config), - prefix=maybe_prefix(prefix, "visual"), - ) \ No newline at end of file + if vllm_version_is("0.10.2"): + self.visual = AscendQwen2VisionTransformer( + self.config.vision_config, + norm_eps=getattr(self.config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config( + vllm_config.quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + else: + self.visual = AscendQwen2VisionTransformer( + self.config.vision_config, + norm_eps=getattr(self.config, "rms_norm_eps", 1e-6), + quant_config=vllm_config.quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) diff --git a/vllm_ascend/models/qwen3.py b/vllm_ascend/models/qwen3.py deleted file mode 100644 index a05106f..0000000 --- a/vllm_ascend/models/qwen3.py +++ /dev/null @@ -1,156 +0,0 @@ -from collections.abc import Iterable -from typing import Optional, Union - -import torch -from torch import nn -from transformers import Qwen3Config -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP -from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer -from vllm.model_executor.models.utils import (AutoWeightsLoader, - PPMissingLayer, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant - - -class CustomQwen3DecoderLayer(Qwen3DecoderLayer): - - def __init__( - self, - config: Qwen3Config, - cache_config: Optional[CacheConfig] = 
None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) - if quant_config is None: - return - - from vllm_ascend.quantization.quant_config import AscendQuantConfig - from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod - - assert isinstance(quant_config, AscendQuantConfig), \ - "Expected quant_config to be an instance of AscendQuantConfig" - - if isinstance(self.self_attn.qkv_proj.quant_method.quant_method, - AscendW8A8LinearMethod): - self.input_layernorm = AddRMSNormW8A8Quant( - config.hidden_size, - layer=self.self_attn.qkv_proj, - eps=config.rms_norm_eps) - if isinstance(self.mlp.gate_up_proj.quant_method.quant_method, - AscendW8A8LinearMethod): - self.post_attention_layernorm = AddRMSNormW8A8Quant( - config.hidden_size, - layer=self.mlp.gate_up_proj, - eps=config.rms_norm_eps) - - -ALL_DECODER_LAYER_TYPES = { - "attention": CustomQwen3DecoderLayer, -} - - -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, - # otherwise (seq_len, ). 
- "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }) -class CustomQwen3Model(Qwen2Model): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - decoder_layer_type=CustomQwen3DecoderLayer) - - -class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): - # add `CustomQwen3Model` to init self.model - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.lora_config = lora_config - - self.quant_config = quant_config - self.model = CustomQwen3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - if get_pp_group().is_last_rank: - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() - - self.logits_processor = LogitsProcessor(config.vocab_size) - - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: 
SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - ) - return loader.load_weights(weights) diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 2fa10f0..bc0a04e 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -17,14 +17,14 @@ # Adapted from vllm/model_executor/models/qwen3_moe.py # This file is a part of the vllm-ascend project. -from typing import Optional, Union +from typing import Optional import torch from torch import nn from transformers import PretrainedConfig from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, CompilationLevel, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, get_tp_group) from vllm.forward_context import get_forward_context @@ -45,11 +45,8 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention, from vllm.model_executor.models.utils import ( PPMissingLayer, extract_layer_index, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.sequence import IntermediateTensors from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.ops.sequence_parallel import (MetadataForPadding, - init_metadata_for_sp) from vllm_ascend.utils import vllm_version_is @@ -101,7 +98,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): self, hidden_states, attn_metadata=None, - _metadata_for_padding: Optional[MetadataForPadding] = None, ): if attn_metadata is None: attn_metadata = get_forward_context().attn_metadata @@ 
-120,7 +116,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): top_k=self.top_k, enable_force_load_balance=enable_force_load_balance, shared_experts=None, - _metadata_for_padding=_metadata_for_padding, ) return hidden_states @@ -175,9 +170,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): quant_config=quant_config, prefix=f"{prefix}.mlp") else: - self.mlp = Qwen3MoeSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if vllm_version_is("0.10.2"): + self.mlp = Qwen3MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, + prefix=f"{prefix}.mlp") else: self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, @@ -189,60 +189,6 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.enable_sequence_parallelism = ( - vllm_config.compilation_config.pass_config. 
- enable_sequence_parallelism if vllm_config is not None else False) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - _metadata_for_padding: Optional[MetadataForPadding] = None, - ) -> torch.Tensor: - - # To prevent precision issues during the decoder phase when only prefilling enables SP - if not self.enable_sequence_parallelism: - self.self_attn.o_proj.reduce_results = True - else: - self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True - - # Self Attention - if residual is None: - residual = hidden_states - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - residual = _metadata_for_padding.padding_slice(residual) - - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - hidden_states = _metadata_for_padding.allgather_unpadding_aligned( - hidden_states) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter( - hidden_states) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - if not self.use_aclgraph: - hidden_states = self.mlp( - hidden_states, _metadata_for_padding=_metadata_for_padding) - else: - hidden_states = self.mlp(hidden_states) - - return hidden_states, residual - @support_torch_compile class CustomQwen3MoeModel(Qwen3MoeModel): @@ -254,11 +200,8 @@ class CustomQwen3MoeModel(Qwen3MoeModel): quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - self.num_redundant_experts 
= parallel_config.num_redundant_experts - else: - eplb_config = parallel_config.eplb_config - self.num_redundant_experts = eplb_config.num_redundant_experts + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.config = config @@ -281,60 +224,8 @@ class CustomQwen3MoeModel(Qwen3MoeModel): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - _metadata_for_padding: Optional[MetadataForPadding] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - residual, - _metadata_for_padding=_metadata_for_padding) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - hidden_states = _metadata_for_padding.allgather_unpadding_aligned( - hidden_states) - - return hidden_states - class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", 
"experts.0.down_proj"], - } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) @@ -357,7 +248,6 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism # Set MoE hyperparameters self.expert_weights: list[torch.Tensor] = [] @@ -378,16 +268,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): self.num_moe_layers = len(self.moe_layers) self.num_expert_groups = 1 self.num_shared_experts = 0 - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - _metadata_for_padding = init_metadata_for_sp( - input_ids, self.enable_sequence_parallelism) - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, _metadata_for_padding) - return hidden_states diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py new file mode 100644 index 0000000..47b6d3e --- /dev/null +++ b/vllm_ascend/models/qwen3_next.py @@ -0,0 +1,676 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# mypy: ignore-errors +"""Inference-only Qwen3Next model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN +from vllm import envs +from vllm.attention import AttentionBackend, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, + VllmConfig, get_current_vllm_config) +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + 
get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fla.ops import RMSNormGated +from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule +from vllm.model_executor.layers.fla.ops.fused_recurrent import \ + fused_recurrent_gated_delta_rule +from vllm.model_executor.layers.fused_moe import FusedMoE +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.layernorm import \ + GemmaRMSNorm as Qwen3NextRMSNorm +# yapf: enable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_mixer2 import \ + mamba_v2_sharded_weight_loader +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from vllm.model_executor.models.utils import ( + PPMissingLayer, extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.model_executor.utils import set_weight_attrs +from vllm.transformers_utils.configs import Qwen3NextConfig +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + +from vllm.model_executor.models.qwen3_next import ( # isort: skip + Qwen3NextAttention, 
Qwen3NextDecoderLayer, Qwen3NextForCausalLM, + Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock, + fused_gdn_gating) + + +class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): + + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend + return GDNAttentionBackend + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim, + self.head_v_dim, self.conv_kernel_size, self.num_spec) + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = 
cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = (self.speculative_config.num_speculative_tokens + if self.speculative_config else 0) + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # projection of the input hidden states + self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + self.projection_size_ba = self.num_v_heads * 2 + self.in_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[self.projection_size_qkvz, self.projection_size_ba], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + query_key_settings, + query_key_settings, + value_settings, + ], self.tp_size, self.tp_rank) + }) + + # selective projection used to make dt, B and C input dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), ) + self.A_log = nn.Parameter( + torch.empty( + divide(self.num_v_heads, self.tp_size), + dtype=torch.float32, + )) + + set_weight_attrs(self.A_log, + {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + norm_before_gate=True, + device="npu", + ) + + self.out_proj = RowParallelLinear(self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + 
quant_config=quant_config, + prefix=f"{prefix}.out_proj") + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def _forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_masks = attn_metadata.spec_token_masks + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + + num_actual_tokens = (attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens + + attn_metadata.num_spec_decode_tokens) + num_accepted_tokens = attn_metadata.num_accepted_tokens + + # 1. 
Set up dimensions for reshapes later + projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) + if spec_token_masks is not None: + spec_token_masks = spec_token_masks[:num_actual_tokens] + projected_states_qkvz, projected_states_ba = torch.split( + projected_states, + [ + self.projection_size_qkvz // self.tp_size, + self.projection_size_ba // self.tp_size + ], + dim=-1, + ) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba) + query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), + (query, key, value)) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv[spec_token_masks] + mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 2.2: process the remaining part + if attn_metadata.num_prefills > 0: + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[:attn_metadata + .num_decodes], + # validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, 
key_spec, value_spec = self.rearrange_mixed_qkv( + mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec) + + beta = b.sigmoid() + g = fused_gdn_gating(self.A_log, a, self.dt_bias) + g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g[:, spec_token_masks] + beta_spec = beta[:, spec_token_masks] + g_non_spec = g[:, ~spec_token_masks] + beta_non_spec = beta[:, ~spec_token_masks] + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 3. Recurrent attention + # 3.1: process the mutlti-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[:attn_metadata. + num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 3.2: process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[ + non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] = 0 + + batch_size = initial_state.shape[0] + core_attn_out = [] + last_recurrent_state = [] + + for b_idx in range(batch_size): + start, end = non_spec_query_start_loc[ + b_idx], non_spec_query_start_loc[b_idx + 1] + cur_q = query_non_spec[:, start:end, ...] + cur_k = key_non_spec[:, start:end, ...] + cur_v = value_non_spec[:, start:end, ...] + cur_g = g_non_spec[:, start:end, ...] + cur_b = beta_non_spec[:, start:end, ...] 
+ cur_state = initial_state[b_idx].unsqueeze(0) + + ( + cur_core_attn_out_non_spec, + cur_last_recurrent_state, + ) = chunk_gated_delta_rule( + query=cur_q, + key=cur_k, + value=cur_v, + g=cur_g, + beta=cur_b, + initial_state=cur_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + + core_attn_out.append(cur_core_attn_out_non_spec) + last_recurrent_state.append(cur_last_recurrent_state) + + tar_dtype = core_attn_out[0].dtype + tar_device = core_attn_out[0].device + tar_shape = list(core_attn_out[0].shape) + tar_shape[1] = non_spec_query_start_loc[-1] + core_attn_out_non_spec = torch.empty(tar_shape, + dtype=tar_dtype, + device=tar_device) + for b_idx in range(batch_size): + cur_core_attn_out = core_attn_out[b_idx] + start, end = non_spec_query_start_loc[ + b_idx], non_spec_query_start_loc[b_idx + 1] + core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out + last_recurrent_state = torch.cat(last_recurrent_state, dim=0) + + # Init cache + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[:attn_metadata. 
+ num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + # Merge core attention output + if (spec_sequence_masks is not None + and core_attn_out_non_spec is not None): + core_attn_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + core_attn_out[:, spec_token_masks] = core_attn_out_spec + core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + elif spec_sequence_masks is not None: + core_attn_out = core_attn_out_spec + else: + core_attn_out = core_attn_out_non_spec + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, '... h d -> ... 
(h d)') + + output[:num_actual_tokens], _ = self.out_proj(core_attn_out) + + +class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer): + + def __init__( + self, + vllm_config: VllmConfig, + layer_type: str, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + speculative_config = vllm_config.speculative_config + + self.layer_type = layer_type + self.layer_idx = extract_layer_index(prefix) + + if self.layer_type == "linear_attention": + self.linear_attn = CustomQwen3NextGatedDeltaNet( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f'{prefix}.linear_attn') + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f'{prefix}.self_attn', + ) + else: + raise ValueError(f"Invalid layer_type {self.layer_type}") + + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (self.layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (self.layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3NextSparseMoeBlock(vllm_config=vllm_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + + self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + self.layer_scale = getattr(config, "layer_scale", False) + if self.layer_scale: + self.attn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + config.hidden_size, 
+ dtype=config.torch_dtype, + ), ) + self.ffn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + config.hidden_size, + dtype=config.torch_dtype, + ), ) + + +@support_torch_compile +class CustomQwen3NextModel(Qwen3NextModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config: Qwen3NextConfig = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + lora_config = vllm_config.lora_config + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + return CustomQwen3NextDecoderLayer( + vllm_config, + layer_type=config.layer_types[extract_layer_index(prefix)], + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.norm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("in_proj", "in_proj_qkvz", 0), + ("in_proj", "in_proj_ba", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if 
"rotary_emb.inv_freq" in name: + continue + + if name.startswith("mtp."): + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # name = apply_attn_prefix(name, params_dict) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Qwen3Next currently does not support prefix caching" + assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1" + self.quant_config = vllm_config.quant_config + self.config = config + self.scheduler_config = scheduler_config + self.model = CustomQwen3NextModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # Set MoE hyperparameters + self.expert_weights = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3NextDecoderLayer) + if 
isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No Qwen3Next layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index a1e7417..381c1b6 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -20,6 +20,7 @@ import torch import vllm_ascend.ops.common_fused_moe # noqa import vllm_ascend.ops.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa +import vllm_ascend.ops.register_custom_ops # noqa import vllm_ascend.ops.vocab_parallel_embedding # noqa from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.rotary_embedding import ( @@ -34,19 +35,20 @@ class dummyFusionOp: def register_dummy_fusion_op() -> None: - torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm") - torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm") - torch.ops._C.static_scaled_fp8_quant = dummyFusionOp( + torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm") + torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp( + name="fused_add_rms_norm") + torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp( name="static_scaled_fp8_quant") - torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp( + torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp( name="dynamic_scaled_fp8_quant") - torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp( + torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = 
dummyFusionOp( name="dynamic_per_token_scaled_fp8_quant") - torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp( + torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp( name="rms_norm_static_fp8_quant") - torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp( + torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp( name="fused_add_rms_norm_static_fp8_quant") - torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp( + torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp( name="rms_norm_dynamic_per_token_quant") diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py index 26082fe..fb1abe6 100644 --- a/vllm_ascend/ops/activation.py +++ b/vllm_ascend/ops/activation.py @@ -35,8 +35,10 @@ class AscendSiluAndMul(SiluAndMul): from vllm_ascend.utils import is_310p + torch.ops.vllm.maybe_prefetch_mlp_down_proj(x) if is_310p(): out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) else: out = torch_npu.npu_swiglu(x) + torch.ops.vllm.maybe_wait_prefetch_done(out) return out diff --git a/vllm_ascend/ops/casual_conv1d.py b/vllm_ascend/ops/casual_conv1d.py new file mode 100644 index 0000000..2d00889 --- /dev/null +++ b/vllm_ascend/ops/casual_conv1d.py @@ -0,0 +1,539 @@ +# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao. 
+# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# mypy: ignore-errors + +from typing import Optional, Union + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +PAD_SLOT_ID = -1 + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = 
None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + out_ref = [] + out_ref_b = [] + seqlens = query_start_loc[1:] - query_start_loc[:-1] + seqlens = seqlens.tolist() + splits = torch.split(x, seqlens, dim=-1) + + for i in range(len(seqlens)): + x_s = splits[i] + if cache_indices[i] == PAD_SLOT_ID: + continue + out_ref_b.append( + causal_conv1d_ref( + x_s, + weight, + bias, + activation=activation, + return_final_states=True, + final_states_out=conv_states[cache_indices[i]].unsqueeze(0), + initial_states=conv_states[cache_indices[i]] + if has_initial_state[i] else None)) + out_ref.append(torch.cat([t[0] for t in out_ref_b], 
dim=-1)) + out_ref_tensor = torch.cat(out_ref, dim=0) + return out_ref_tensor + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + num_accepted_tokens_ptr, + intermediate_conv_window_ptr, + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_inter_seq: tl.constexpr, + stride_inter_step: tl.constexpr, + stride_inter_dim: tl.constexpr, + stride_inter_win: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, + SAVE_INTERMEDIATE: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices).to( + tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this 
is not the actual sequence + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + + idx_seq) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + #col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # The conv_state updates works in a sliding window manner, + # at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. 
+ conv_state_ptrs_source = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + 1) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ((conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :]) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim + ) # [BLOCK_N] + + x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ((idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + conv_state_ptrs_target = (conv_state_base + + (idx_tokens * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, + other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * 
stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.static_range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + # mask_1d = (idx_token < seqlen) & ( + # idx_feats < dim + # ) # token-index # feature-index + maskL = idx_feats < dim + maskR = tl.full(maskL.shape, False, tl.int1) + mask_1d = tl.where(idx_token < seqlen, maskL, maskR) + + o_ptrs = (o_ptr + (idx_seq) * stride_o_seq + + idx_token * stride_o_token + (idx_feats * stride_o_dim)) + + tl.store(o_ptrs, acc, mask=mask_1d) + + if SAVE_INTERMEDIATE: + # Save the window state after 
consuming this token + # Layout: [seq(cache line), step, dim, win(K-1)] + base_ptr = (intermediate_conv_window_ptr + + conv_state_batch_coord * stride_inter_seq + + idx_token * stride_inter_step + + idx_feats * stride_inter_dim) + if KERNEL_WIDTH >= 2: + tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) + if KERNEL_WIDTH >= 3: + tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) + if KERNEL_WIDTH >= 4: + tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) + + +def causal_conv1d_update_npu( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + intermediate_conv_window: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + [shape=2: single token prediction] + [shape=3: single or multiple tokens prediction] + conv_state: (..., dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. 
+ pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert ( + conv_state.stride(-2) == 1 + ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch, ) == conv_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + stride_x_seq, stride_x_dim, stride_x_token = x.stride( + ) # X (batch, dim, seqlen) + + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( + ) + stride_state_indices = (conv_state_indices.stride(0) + if conv_state_indices 
is not None else 0) + state_len = width - 1 + (seqlen - 1) # effective state_len needed + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + # prepare intermediate buffer strides if provided + if intermediate_conv_window is not None: + stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( + intermediate_conv_window.stride(0), + intermediate_conv_window.stride(1), + intermediate_conv_window.stride(2), + intermediate_conv_window.stride(3), + ) + else: + stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + num_accepted_tokens, + intermediate_conv_window + if intermediate_conv_window is not None else x, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_inter_seq, + stride_inter_step, + stride_inter_dim, + stride_inter_win, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=128, + SAVE_INTERMEDIATE=intermediate_conv_window is not None, + ) + if unsqueeze: + out = out.squeeze(-1) + return out diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 607991c..ac22b69 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -14,212 +14,32 @@ # See the License for the specific language 
governing permissions and # limitations under the License. # - -from typing import Any, Callable, Optional +import os.path +from typing import Callable, Optional import torch import torch_npu from vllm.config import CompilationLevel, get_current_vllm_config -from vllm.distributed import get_dp_group, get_ep_group, get_tp_group +from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, + tensor_model_parallel_all_reduce) from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fused_moe.config import \ - FusedMoEParallelConfig # isort: skip from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, UnquantizedFusedMoEMethod) + FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map) +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, - AlltoAllCommImpl, - MC2CommImpl) +from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ - setup_token_dispatchers -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is +from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, + determine_default_log2phy_map) +from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer +from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ -def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - 
apply_router_weight_on_input: bool = False, - use_int8_w8a8: bool = False, - use_int4_w4a8: bool = False, - global_num_experts: Optional[int] = None, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - # For TorchAir graph - is_torchair: bool = False, - # For Cube/Vector parallel - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - # For load balance - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, -) -> torch.Tensor: - # Check constraints - assert hidden_states.shape[1] == w1.shape[1], ( - f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}") - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - if (use_int8_w8a8 or use_int4_w4a8): - assert w1_scale is not None and w2_scale is not None, \ - "INT8 quantization requires weight scales." 
- - w1_scale = w1_scale.to(torch.float32) - down_scale = [w2_scale] - down_output_dtype = w2_scale.dtype - else: - down_scale = None - down_output_dtype = None - - moe_comm_method = get_forward_context().moe_comm_method - assert moe_comm_method is not None, "Missing communication context" - - num_experts = w1.shape[0] - - permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute( - hidden_states, topk_ids, topk_weights, expert_map, num_experts, - use_int8_w8a8 or use_int4_w4a8) - - gate_up_output = torch_npu.npu_grouped_matmul( - x=[permuted_hidden_states], - weight=[w1], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=expert_tokens, - output_dtype=torch.int32 if use_int8_w8a8 else None, - )[0] - - if (use_int8_w8a8 or use_int4_w4a8): - activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant( - x=gate_up_output, - weight_scale=w1_scale, - activation_scale=dynamic_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=expert_tokens, - activate_left=True, - quant_mode=1, - ) - activated_output_scale = [activated_output_scale] - else: - activated_output = torch_npu.npu_swiglu(gate_up_output) - activated_output_scale = None - - down_output = torch_npu.npu_grouped_matmul( - x=[activated_output], - weight=[w2], - scale=down_scale, - per_token_scale=activated_output_scale, - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=expert_tokens, - output_dtype=down_output_dtype, - )[0] - - moe_comm_method.unpermute(down_output, hidden_states) - - return hidden_states - - -def fused_experts_moge( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - moe_parallel_config: FusedMoEParallelConfig, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - global_num_experts: int, - expert_map: torch.Tensor = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - """ - - Args: - hidden_states: 
Hidden states of shape (num_tokens, hidden_size). - w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). - w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). - topk_weights: Routing weights of shape (num_tokens, top_k). - topk_ids: Selected expert IDs of shape (num_tokens, top_k). - top_k: Number of experts to select. - expert_map: Expert mapping of shape (num_experts,). - - Returns: - hidden_states: Hidden states after routing. - """ - ep_size = moe_parallel_config.ep_size - local_num_experts = global_num_experts // ep_size - local_num_group = top_k // ep_size - - bsz, _ = hidden_states.shape - flatten_topk_ids = topk_ids.view(-1) - sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) - sorted_topk_ids = sorted_topk_ids.to(torch.int32) - sorted_hidden_states = hidden_states.index_select( - 0, sorted_topk_ids // local_num_group) - - experts_id = torch.arange(0, - local_num_experts, - dtype=topk_ids.dtype, - device=topk_ids.device) - num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to( - torch.float32).sum(0) - topk_scales = topk_weights.view(-1).index_select( - 0, sorted_topk_ids).unsqueeze(-1) - group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) - - gate_up_out = torch_npu.npu_grouped_matmul( - x=[sorted_hidden_states], - weight=[w1], - split_item=2, - group_list_type=0, - group_type=0, - group_list=group_list, - )[0] - - if is_310p(): - gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( - torch.float16) - else: - gate_up_out = torch_npu.npu_swiglu(gate_up_out) - gate_up_out *= topk_scales - - down_out_list = torch_npu.npu_grouped_matmul( - x=[gate_up_out], - weight=[w2], - split_item=2, - group_list_type=0, - group_type=0, - group_list=group_list, - )[0] - - unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32) - unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids) - final_hidden_states = 
unsorted_hidden_states.reshape( - bsz, top_k // ep_size, -1).sum(1) - - return final_hidden_states - - def unquantized_fused_moe_init_func(self, *args, **kwargs): original_unquantized_fused_moe_init_func(self, *args, **kwargs) @@ -235,67 +55,7 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs): self.use_aclgraph = (vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager) - - -def forward_oot_v01011( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor: - - topk_weights, topk_ids, _ = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=1.0, - e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts) - - if topk_ids.shape[1] < top_k or is_310p(): - assert global_num_experts is not None - return fused_experts_moge( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - moe_parallel_config=self.moe.moe_parallel_config, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - global_num_experts=global_num_experts, - 
expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input) - - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - global_num_experts=global_num_experts, - expert_map=expert_map, - ) + self.transpose = True def forward_oot( @@ -321,7 +81,7 @@ def forward_oot( logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor: - topk_weights, topk_ids, _ = select_experts( + topk_weights, topk_ids, row_idx = select_experts( hidden_states=x, router_logits=router_logits, top_k=top_k, @@ -335,40 +95,35 @@ def forward_oot( e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts) - if topk_ids.shape[1] < top_k or is_310p(): - assert global_num_experts is not None - return fused_experts_moge( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - moe_parallel_config=self.moe.moe_parallel_config, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input) - - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - global_num_experts=global_num_experts, - expert_map=expert_map, - ) + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + global_num_experts=global_num_experts, + expert_map=expert_map) def process_weights_after_loading(self, layer): super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer) - w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose( - 1, 2).contiguous() - layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False) + if 
self.transpose: + w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose( + 1, 2).contiguous() + layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False) - w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose( - 1, 2).contiguous() - layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) + w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose( + 1, 2).contiguous() + layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) + + self.transpose = False + else: + w13_data = self._maybe_pad_weight(layer.w13_weight.data) + layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False) + + w2_data = self._maybe_pad_weight(layer.w2_weight.data) + layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) if not is_310p(): layer.w13_weight.data = torch_npu.npu_format_cast( @@ -378,119 +133,88 @@ def process_weights_after_loading(self, layer): class AscendFusedMoE(FusedMoE): + moe_counter = -1 - def __init__( - self, - num_experts, - top_k, - hidden_size, - intermediate_size, - params_dtype=None, - reduce_results=False, - renormalize=True, - use_grouped_topk=False, - num_expert_group=None, - topk_group=None, - quant_config=None, - tp_size=None, - ep_size=None, - dp_size=None, - prefix="", - custom_routing_function=None, - scoring_func="softmax", - routed_scaling_fator: float = 1.0, - e_score_correction_bias=None, - apply_router_weight_on_input=False, - activation="silu", - enable_eplb=False, - num_redundant_experts=0, - has_bias=False, - ): - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - super().__init__( - num_experts, - top_k, - hidden_size, - intermediate_size, - params_dtype, - reduce_results, - renormalize, - use_grouped_topk, - num_expert_group, - topk_group, - quant_config, - tp_size, - ep_size, - dp_size, - prefix, - custom_routing_function, - scoring_func, - e_score_correction_bias, - apply_router_weight_on_input, - activation, - enable_eplb, - num_redundant_experts, - 
has_bias, - ) - else: - super().__init__( - num_experts, - top_k, - hidden_size, - intermediate_size, - params_dtype, - reduce_results, - renormalize, - use_grouped_topk, - num_expert_group, - topk_group, - quant_config, - tp_size, - ep_size, - dp_size, - prefix, - custom_routing_function, - scoring_func, - routed_scaling_fator, - e_score_correction_bias, - apply_router_weight_on_input, - activation, - enable_eplb, - num_redundant_experts, - has_bias, - ) - - setup_token_dispatchers(self.moe_config.ep_size, - top_k=self.top_k, - num_experts=self.global_num_experts, - num_local_experts=self.local_num_experts) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + AscendFusedMoE.moe_counter += 1 + self.moe_instance_id = AscendFusedMoE.moe_counter self.moe_config.tp_group = get_tp_group() self.moe_config.dp_group = get_dp_group() self.moe_config.ep_group = get_ep_group() self.moe_config.mc2_group = get_mc2_group() + ascend_config = get_ascend_config() + self.dynamic_eplb = ascend_config.dynamic_eplb + self.expert_map_path = ascend_config.expert_map_path + self.global_redundant_expert_num = ascend_config.init_redundancy_expert + # static eplb initializing with expert_map_path + if self.expert_map_path and os.path.exists( + self.expert_map_path) and os.access(self.expert_map_path, + os.R_OK): + self.expert_load_balancer = ExpertLoadBalancer( + self.expert_map_path, self.global_num_experts) + self.local_num_experts, self.expert_map = ( + self.expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, self.ep_rank)) + self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( + self.moe_instance_id, self.ep_rank).npu() + self.global_redundant_expert_num = ( + self.expert_load_balancer.get_global_redundant_expert_num()) + else: + # init moe. 
+ self.local_num_experts, self.expert_map = determine_expert_map( + self.ep_size, self.ep_rank, self.global_num_experts) + # dynamic eplb initializing with not expert_map_path + if self.dynamic_eplb: + self.global_redundant_expert_num = ascend_config.init_redundancy_expert + self.local_num_experts, self.expert_map = determine_default_expert_map( + self.global_num_experts, self.ep_size, self.ep_rank, + self.global_redundant_expert_num) + self.log2phy = determine_default_log2phy_map( + self.global_num_experts, self.ep_size, self.ep_rank, + self.global_redundant_expert_num) + local_num_experts = (torch.sum( + self.expert_map != -1) if self.expert_map is not None else + self.global_num_experts) + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) - for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}: - setattr( - self, method.__name__.lower(), - method(moe_config=self.moe_config)) # type: ignore[abstract] + setup_moe_comm_method(self.moe_config) + + def update_expert_map(self, new_expert_map): + self.expert_map = new_expert_map + + def get_map(self): + return self.expert_map + + def get_log2phy_map(self): + return self.logical_to_physical_map + + def clear_moe_load(self): + if self.moe_load is not None: + self.moe_load.zero_() + + def maybe_all_reduce_tensor_model_parallel( + self, final_hidden_states: torch.Tensor): + """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`, + and `alltoallcommimpl`, we do not need to all-reduce the final outputs since + the outputs are already aggregated across tensor parallel ranks in the + `finalize` function. In `allgathercommimpl`, we still need to all-reduce the + outputs since each rank only has partial outputs. 
+ """ + return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel( + final_hidden_states) def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None forward_context = get_forward_context() - moe_comm_method_name = forward_context.moe_comm_method_name - - # TODO: Can we refactor this logic to model_runner? - # TODO: Adjusted logic to differentiate between A2 and A3, we check ep_size here since mc2 only support ep_size >= 16 on A3 now - if self.moe_config.ep_size < 16: - moe_comm_method_name = "allgathercommimpl" - - forward_context.moe_comm_method = getattr(self, moe_comm_method_name) - hidden_states, router_logits = forward_context.moe_comm_method.prepare( - hidden_states=hidden_states, router_logits=router_logits) + hidden_states=hidden_states, + router_logits=router_logits, + replace_allreduce=forward_context.sp_enabled) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -514,6 +238,12 @@ class AscendFusedMoE(FusedMoE): logical_to_physical_map=self.logical_to_physical_map, logical_replica_count=self.logical_replica_count, ) + if isinstance(final_hidden_states, tuple): + final_hidden_states, group_list_type, expert_tokens = final_hidden_states + + if self.dynamic_eplb: + self.moe_load += expert_tokens if group_list_type else \ + torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) final_hidden_states = forward_context.moe_comm_method.finalize( hidden_states=final_hidden_states, @@ -521,11 +251,118 @@ class AscendFusedMoE(FusedMoE): return final_hidden_states + def transpose_weight(self, loaded_weight, expert_data, shard_dim): + # Ensure training and inference weight shapes match during RL weight updates + if ( + loaded_weight.shape[1] != expert_data.shape[1] and \ + loaded_weight.shape[0] != expert_data.shape[0] + ): + shard_dim = int(not shard_dim) + loaded_weight = loaded_weight.transpose(0, 1).contiguous() + return loaded_weight, shard_dim + + def _load_w13(self, 
+ expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False): + # Index the loaded weight for tp sharding. + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + loaded_weight, shard_dim = self.transpose_weight( + loaded_weight, expert_data, shard_dim) + shard_size = expert_data.shape[shard_dim] // 2 + if not load_full: + loaded_weight = loaded_weight.narrow(shard_dim, + shard_size * tp_rank, + shard_size) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + + def _load_w2(self, + expert_data: torch.Tensor, + shard_dim: int, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False): + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + loaded_weight, shard_dim = self.transpose_weight( + loaded_weight, expert_data, shard_dim) + shard_size = expert_data.shape[shard_dim] + if not load_full: + loaded_weight = loaded_weight.narrow(shard_dim, + shard_size * tp_rank, + shard_size) + # w2, down_proj: Load into only logical weight of w2. 
+ expert_data.copy_(loaded_weight) + + +class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): + + def __init__( + self, + shared_experts: torch.nn.Module, + use_overlapped: bool = True, + **kwargs, + ): + AscendFusedMoE.__init__(self, **kwargs) + self._shared_experts = shared_experts + self.use_overlapped = use_overlapped + self.shared_expert_stream = None + ascend_config = get_ascend_config() + self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert + if self.multistream_overlap_shared_expert: + self.shared_expert_stream = torch.npu.Stream() + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + shared_out, fused_out = AscendFusedMoE.forward( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out + + def forward_impl(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + # Make sure the shared experts stream begins after hidden_states are ready. + if self.multistream_overlap_shared_expert: + self.shared_expert_stream.wait_stream( # type: ignore + torch.npu.current_stream()) + with npu_stream_switch(self.shared_expert_stream, + enabled=self.multistream_overlap_shared_expert): + # Use a separate stream to run shared experts. + shared_out = self._shared_experts(hidden_states) + + # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` + forward_context = get_forward_context() + moe_comm_type = forward_context.moe_comm_type + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: + shared_out = tensor_model_parallel_all_reduce(shared_out) + fused_output = AscendFusedMoE.forward_impl( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + # Make sure the default stream waits for the shared experts stream to finish. 
+ if self.multistream_overlap_shared_expert: + torch.npu.current_stream().wait_stream(self.shared_expert_stream) + return shared_out, fused_output + UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading - -if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - UnquantizedFusedMoEMethod.forward_oot = forward_oot_v01011 -else: - UnquantizedFusedMoEMethod.forward_oot = forward_oot +UnquantizedFusedMoEMethod.forward_oot = forward_oot diff --git a/vllm_ascend/ops/fla.py b/vllm_ascend/ops/fla.py new file mode 100644 index 0000000..b200c67 --- /dev/null +++ b/vllm_ascend/ops/fla.py @@ -0,0 +1,218 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py +# Copyright (c) 2024, Tri Dao. +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
+# mypy: ignore-errors + +import torch +import torch.nn.functional as F +import triton +from vllm.model_executor.layers.fla.ops.layernorm_guard import \ + layer_norm_fwd_kernel + + +def _layer_norm_fwd( + x, + weight, + bias, + eps, + z=None, + out=None, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N, ) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N, ) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + if not is_rms_norm else None) + rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + with torch.npu.device(x.device.index): + layer_norm_fwd_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) + return out, mean, rstd + + +class LayerNormFn(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, + ): + """If z is not None, we do norm(x) * silu(z) if 
norm_before_gate, else norm(x * silu(z))""" + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = _layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + return y.reshape(x_shape_og) + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = F.normalize(query, p=2, dim=-1) + key = F.normalize(key, p=2, dim=-1) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, sequence_length, num_heads, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - num_heads % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) + key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + tot_heads = num_heads + pad_size + scale = 1 / (query.shape[-1]**0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, + chunk_size, + dtype=torch.bool, + device=query.device), + diagonal=0) + + # chunk decay + g = 
g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - + g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -( + (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_recurrent_state = (torch.zeros(batch_size, sequence_length, + k_head_dim, v_head_dim).to(value) if + initial_state is None else initial_state.to(value)) + + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, + chunk_size, + dtype=torch.bool, + device=query.device), + diagonal=1) + + # for each chunk + for i in range(0, tot_heads // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * + decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * + (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( + -1, -2) @ v_new) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], + core_attn_out.shape[1], -1, + core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :num_heads] + core_attn_out = core_attn_out.transpose(1, + 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 14396c1..97489f9 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -19,13 +19,9 @@ import os from typing 
import Any, Callable, Optional import torch -import torch.distributed as dist import torch_npu -from torch import nn from vllm.config import get_current_vllm_config -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, get_tp_group) from vllm.forward_context import get_forward_context @@ -39,70 +35,16 @@ from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import FusedMoEState -from vllm_ascend.distributed.communication_op import \ - data_parallel_reduce_scatter from vllm_ascend.distributed.parallel_state import get_mc2_group +from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, + determine_default_log2phy_map) from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp -from vllm_ascend.ops.sequence_parallel import MetadataForPadding -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor, +from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, get_all_reduce_merge_state, - get_rm_router_logits_state, is_310p) - - -def unified_fused_experts_eager(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - row_idx: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - w1_scale: Optional[torch.Tensor] = None, - w1_scale_bias: Optional[torch.Tensor] = None, - 
w2_scale: Optional[torch.Tensor] = None, - w2_scale_bias: Optional[torch.Tensor] = None, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[Any] = None, - shared_dequant_scale: Optional[Any] = None, - mc2_mask: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False): - token_dispatcher = get_forward_context().token_dispatcher - - results = token_dispatcher.token_dispatch( - hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, - expert_map=expert_map, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts, - shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale, - mc2_mask=mc2_mask, - apply_router_weight_on_input=apply_router_weight_on_input, - with_quant=with_quant) - - expert_output = unified_apply_mlp( - hidden_states=results["hidden_states"], - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=results["group_list"], - dynamic_scale=results.get("dynamic_scale"), - group_list_type=results.get("group_list_type"), - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias, - topk_scales=results.get("topk_scales"), - with_quant=with_quant) - final_hidden_states = token_dispatcher.token_combine(expert_output) - return final_hidden_states + get_rm_router_logits_state, is_310p, + vllm_version_is) class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): @@ -115,6 +57,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): self.global_batch_size = vllm_config.scheduler_config.max_num_seqs self.max_model_len = vllm_config.model_config.max_model_len get_ascend_config() + self.dynamic_eplb = get_ascend_config().dynamic_eplb try: device_group = get_mc2_group().device_group @@ -182,17 +125,19 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): if enable_force_load_balance and not self.use_aclgraph: topk_ids = 
torch.randint_like(topk_ids, 0, global_num_experts) - return unified_fused_experts_eager(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, - expert_map=expert_map, - shared_experts=shared_experts, - mc2_mask=kwargs.get( - "mc2_mask", None), - with_quant=False) + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + global_num_experts=global_num_experts, + expert_map=expert_map, + shared_experts=shared_experts, + need_trans=True, + dynamic_eplb=self.dynamic_eplb) class AscendFusedMoE(FusedMoE): @@ -290,42 +235,67 @@ class AscendFusedMoE(FusedMoE): self.moe_parallel_config.ep_size, is_deepseek_v3_r1) ascend_config = get_ascend_config() - expert_map_path = ascend_config.expert_map_path - if expert_map_path and os.path.exists(expert_map_path): - # moe expert load balance - expert_load_balancer = ExpertLoadBalancer(expert_map_path, - self.global_num_experts) - self.local_num_experts, self.expert_map = \ - expert_load_balancer.get_rank_placement_map( - self.moe_instance_id, - get_ep_group().rank_in_group) - self.log2phy = expert_load_balancer.get_rank_log2phy_map( - self.moe_instance_id, - get_ep_group().rank_in_group) - self.global_redundant_expert_num = \ - expert_load_balancer.get_global_redundant_expert_num() + self.dynamic_eplb = ascend_config.dynamic_eplb + self.expert_map_path = ascend_config.expert_map_path + self.global_redundant_expert_num = ascend_config.init_redundancy_expert + self.global_num_experts = num_experts + self.global_redundant_expert_num + # static eplb initializing with expert_map_path + if self.expert_map_path and os.path.exists( + self.expert_map_path) and os.access(self.expert_map_path, + os.R_OK): + self.expert_load_balancer = ExpertLoadBalancer( + self.expert_map_path, 
self.global_num_experts) + self.local_num_experts, self.expert_map = ( + self.expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, self.ep_rank)) + self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( + self.moe_instance_id, self.ep_rank).npu() + self.global_redundant_expert_num = ( + self.expert_load_balancer.get_global_redundant_expert_num()) else: - # Create a tensor of size num_experts filled with -1 + # init moe. self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, - get_ep_group().rank_in_group, self.global_num_experts) + self.ep_size, self.ep_rank, self.global_num_experts) + # dynamic eplb initializing with not expert_map_path + if self.dynamic_eplb: + self.global_redundant_expert_num = ascend_config.init_redundancy_expert + self.local_num_experts, self.expert_map = determine_default_expert_map( + self.global_num_experts, self.ep_size, self.ep_rank, + self.global_redundant_expert_num) + self.log2phy = determine_default_log2phy_map( + self.global_num_experts, self.ep_size, self.ep_rank, + self.global_redundant_expert_num) + local_num_experts = (torch.sum(self.expert_map != -1) + if self.expert_map is not None else num_experts) + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") - moe = FusedMoEConfig.make( - num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - # TODO (bnell): this needs to be fixed for quantized types. 
- in_dtype=params_dtype, - quant_config=quant_config) - + if vllm_version_is("0.10.2"): + moe = FusedMoEConfig.make( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + # TODO (bnell): this needs to be fixed for quantized types. + in_dtype=params_dtype, + quant_config=quant_config) + else: + moe = FusedMoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=params_dtype, + ) self.moe_config = moe + # TODO: The self.moe_config.tp_size here is not correct, fixme soon if quant_config is None: self.quant_method = AscendUnquantizedFusedMoEMethod(moe) @@ -337,6 +307,11 @@ class AscendFusedMoE(FusedMoE): local_num_experts = torch.sum(self.expert_map != -1) \ if self.expert_map is not None else num_experts + self.moe_load = None + + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) + moe_quant_params = { "num_experts": local_num_experts, "hidden_size": hidden_size, @@ -354,34 +329,27 @@ class AscendFusedMoE(FusedMoE): # NOTE: self.tp_group is not expert_tp_group self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) - self.token_dispatcher = None - ep_size = (get_ep_group().world_size if - vllm_config.parallel_config.enable_expert_parallel else 1) - from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ - setup_token_dispatchers - setup_token_dispatchers( - ep_size, - top_k=self.top_k, - num_experts=self.global_num_experts, - num_global_redundant_experts=self.global_redundant_expert_num, - num_local_experts=self.local_num_experts) + self.moe_config.tp_group = get_tp_group() + self.moe_config.dp_group = get_dp_group() + self.moe_config.ep_group = get_ep_group() + self.moe_config.mc2_group = 
get_mc2_group() + self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num - def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor): - assert (len(x.shape) == 2) - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - buffer[start:end, :].copy_(x) - for idx in range(self.dp_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - get_dp_group().broadcast(buffer[start:end, :], idx) - return buffer + setup_moe_comm_method(self.moe_config) + + def update_expert_map(self, new_expert_map): + self.expert_map = new_expert_map + + def get_map(self): + return self.expert_map + + def get_log2phy_map(self): + return self.logical_to_physical_map + + def clear_moe_load(self): + if self.moe_load is not None: + self.moe_load.zero_() def forward(self, hidden_states: torch.Tensor, @@ -391,8 +359,7 @@ class AscendFusedMoE(FusedMoE): top_k: Optional[int] = None, shared_experts: Optional[Any] = None, gate=None, - replace_allreduce: bool = False, - _metadata_for_padding: Optional[MetadataForPadding] = None): + replace_allreduce: bool = False): assert self.quant_method is not None @@ -401,10 +368,7 @@ class AscendFusedMoE(FusedMoE): else: real_top_k = self.top_k - num_tokens, hidden_size = hidden_states.shape - forward_context = get_forward_context() - fused_moe_state = forward_context.fused_moe_state mc2_mask = forward_context.mc2_mask # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. 
quantized_x_for_share, dynamic_scale_for_share = None, None @@ -413,74 +377,16 @@ class AscendFusedMoE(FusedMoE): # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce shared_hidden_states = shared_experts(hidden_states) - mc2_mask = forward_context.mc2_mask - - enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill - tp_size = get_tensor_model_parallel_world_size() - if enable_sp: - tp_rank = get_tensor_model_parallel_rank() - mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask - chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0) - mc2_mask = chunk_mc2_mask[tp_rank] + if forward_context.sp_enabled: replace_allreduce = True - if (fused_moe_state not in [ - FusedMoEState.AllGather, FusedMoEState.AllGatherEP, - FusedMoEState.NaiveMulticast - ] and not replace_allreduce): - if fused_moe_state in {FusedMoEState.MC2}: - padding_size = forward_context.padded_num_tokens - else: - # TODO: Determine if we can remove the padding - padding_size = tp_size - if num_tokens < padding_size and not self.enable_shared_expert_dp: - hidden_states = nn.functional.pad( - hidden_states, (0, 0, 0, padding_size - num_tokens)) - router_logits = nn.functional.pad( - router_logits, (0, 0, 0, padding_size - num_tokens)) - if tp_size > 1: - tp_rank = get_tensor_model_parallel_rank() - if not self.enable_shared_expert_dp: - chunk_hidden_states = torch.tensor_split(hidden_states, - tp_size, - dim=0) - chunk_router_logits = torch.tensor_split(router_logits, - tp_size, - dim=0) - hidden_states = chunk_hidden_states[tp_rank] - router_logits = chunk_router_logits[tp_rank] - - chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0) - mc2_mask = chunk_mc2_mask[tp_rank] - - if self.dp_size > 1: - if fused_moe_state == FusedMoEState.AllGather: - # NOTE: When in torchair 
graph, it has been padded in model_runner_v1 - max_tokens_across_dp = forward_context.max_tokens_across_dp - if num_tokens < max_tokens_across_dp: - hidden_states = nn.functional.pad( - hidden_states, - (0, 0, 0, max_tokens_across_dp - num_tokens)) - if not self.rm_router_logits: - router_logits = nn.functional.pad( - router_logits, - (0, 0, 0, max_tokens_across_dp - num_tokens)) - hidden_states = get_dp_group().all_gather(hidden_states, 0) - if self.rm_router_logits: - router_logits, _ = gate(hidden_states) - else: - router_logits = get_dp_group().all_gather(router_logits, 0) - - elif fused_moe_state == FusedMoEState.NaiveMulticast: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_dp_cpu) - if self.rm_router_logits: - router_logits, _ = gate(hidden_states) - else: - router_logits = self.naive_multicast( - router_logits, cu_tokens_across_dp_cpu) + hidden_states, router_logits = forward_context.moe_comm_method.prepare( + hidden_states=hidden_states, + router_logits=router_logits, + enable_shared_expert_dp=self.enable_shared_expert_dp, + rm_router_logits=self.rm_router_logits, + replace_allreduce=replace_allreduce, + gate=gate) # Matrix multiply. 
e_hidden_states = self.quant_method.apply( @@ -503,53 +409,27 @@ class AscendFusedMoE(FusedMoE): global_redundant_expert_num=self.global_redundant_expert_num, shared_experts=None, mc2_mask=mc2_mask, - token_dispatcher=self.token_dispatcher, quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, ) + group_list_type = None + if shared_experts: - if isinstance(e_hidden_states, tuple): + if isinstance(e_hidden_states, + tuple) and len(e_hidden_states) == 2: e_hidden_states, shared_hidden_states = e_hidden_states - if (fused_moe_state not in [ - FusedMoEState.AllGather, FusedMoEState.AllGatherEP, - FusedMoEState.NaiveMulticast - ] and not replace_allreduce and not self.enable_shared_expert_dp): - if tp_size > 1: - dist.all_gather(list(chunk_hidden_states), e_hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - dispose_tensor(e_hidden_states) - else: - final_hidden_states = e_hidden_states - if num_tokens < padding_size: - final_hidden_states = final_hidden_states[:num_tokens] - elif self.dp_size > 1 and not self.enable_shared_expert_dp: - if fused_moe_state == FusedMoEState.NaiveMulticast: - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - final_hidden_states = get_dp_group().all_reduce( - e_hidden_states) - final_hidden_states = final_hidden_states[start:end, :] - dispose_tensor(e_hidden_states) - elif fused_moe_state == FusedMoEState.AllGather: - final_hidden_states = data_parallel_reduce_scatter( - e_hidden_states, dim=0) - final_hidden_states = final_hidden_states[:num_tokens] - dispose_tensor(e_hidden_states) - else: - final_hidden_states = e_hidden_states - else: - final_hidden_states = e_hidden_states + if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 3: + e_hidden_states, group_list_type, expert_tokens = e_hidden_states - if tp_size > 1 and not self.all_reduce_merge and fused_moe_state 
in [ - FusedMoEState.AllGather, FusedMoEState.AllGatherEP, - FusedMoEState.NaiveMulticast - ]: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + if self.dynamic_eplb and group_list_type is not None: + self.moe_load += expert_tokens if group_list_type else \ + torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) + + final_hidden_states = forward_context.moe_comm_method.finalize( + hidden_states=e_hidden_states, + reduce_results=(not self.all_reduce_merge)) if shared_experts: return final_hidden_states, shared_hidden_states diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 4f0b550..3dfca53 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -15,50 +15,124 @@ # This file is a part of the vllm-ascend project. # -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, cast import torch -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm -class AddRMSNormW8A8Quant(RMSNorm): - # Fuse AddRmsNorm and W8A8 quantization ops together +def _addrmsnorm_forward_oot( + self, + x: torch.Tensor, + residual: torch.Tensor, + layer: Optional[torch.nn.Module] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + import torch_npu + + from vllm_ascend.utils import is_310p + + if layer is not None and not is_310p(): + x, _, residual = torch_npu.npu_add_rms_norm_quant( + x, + residual, + self.weight, + layer.aclnn_input_scale, + layer.aclnn_input_offset, + epsilon=self.variance_epsilon) + else: + if is_310p(): + orig_dtype = residual.dtype + x = x + residual.to(x.dtype) + residual = x.to(orig_dtype) + x, _ = torch_npu.npu_rms_norm(x, self.weight, + self.variance_epsilon) + else: + x, _, residual = torch_npu.npu_add_rms_norm( + x, residual, self.weight, self.variance_epsilon) + 
torch.ops.vllm.maybe_wait_prefetch_done(x) + return x, residual + + +class AscendRMSNorm(RMSNorm): + + def forward_oot( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + import torch_npu + + if residual is not None: + residual = torch.ops.vllm.maybe_chunk_residual(x, residual) + assert x.size(0) == residual.size(0) + x, residual = _addrmsnorm_forward_oot( + self, x, residual, self.next_need_quant_fusion_linear) + return x, residual + x, residual = torch_npu.npu_rms_norm(x, self.weight, + self.variance_epsilon) + return x + + @property + def next_need_quant_fusion_linear(self): + try: + forward_context = get_forward_context() + if not forward_context.addrmsnorm_quant_fusion_enabled or \ + forward_context.layer_idx == forward_context.num_hidden_layers: + return None + except AssertionError: + return None + + next_linear = None + model_instance = forward_context.model_instance + layer_idx = forward_context.layer_idx + fusion_linear = forward_context.fusion_linear + next_linear = None + if fusion_linear == "qkv_dense": + next_linear = model_instance.model.layers[ + layer_idx].self_attn.qkv_proj + forward_context.fusion_linear = "gate_up_dense" + elif fusion_linear == "gate_up_dense": + next_linear = model_instance.model.layers[ + layer_idx].mlp.gate_up_proj + forward_context.fusion_linear = "qkv_dense" + # if prefetch_mlp_weight enabled, following accumulation operation + # does not need to be repeated + if not forward_context.prefetch_mlp_enabled: + forward_context.layer_idx += 1 + from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod + if next_linear is not None and \ + not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod): + next_linear = None + return next_linear + + +class AscendQuantRMSNorm(AscendRMSNorm): def __init__( self, hidden_size: int, - layer: torch.nn.Module, eps: float = 1e-6, var_hidden_size: Optional[int] = None, has_weight: bool = 
True, dtype: Optional[torch.dtype] = None, ) -> None: super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) - self.layer = layer + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), + requires_grad=False) - def forward( + def forward_oot( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - import torch_npu - + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if residual is not None: - x, _, residual = torch_npu.npu_add_rms_norm_quant( - x, - residual, - self.weight, - self.layer.aclnn_input_scale, - self.layer.aclnn_input_offset, - epsilon=self.variance_epsilon) - return x, residual - - x, residual = torch_npu.npu_rms_norm(x, self.weight, - self.variance_epsilon) - return x + x, residual = super().forward_oot(x, residual) + return x.add_(self.bias), residual + return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias) -class AscendRMSNorm(RMSNorm): +class AscendGemmaRMSNorm(GemmaRMSNorm): def forward_oot( self, @@ -73,13 +147,13 @@ class AscendRMSNorm(RMSNorm): orig_dtype = residual.dtype x = x + residual.to(x.dtype) residual = x.to(orig_dtype) - x, _ = torch_npu.npu_rms_norm(x, self.weight, + x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon) else: x, _, residual = torch_npu.npu_add_rms_norm( - x, residual, self.weight, self.variance_epsilon) + x, residual, 1.0 + self.weight, self.variance_epsilon) return x, residual - x, residual = torch_npu.npu_rms_norm(x, self.weight, - self.variance_epsilon) + x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, + self.variance_epsilon) return x diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index e2f427e..51399cc 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -1,45 +1,159 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ -Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -This file is a part of the vllm-ascend project. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +To customize linear communication groups or forward of classes in this file, +extend new linear operations in linear_op.py. +The classes in this file should not be modified, including AscendQKVParallelLinear, +AscendMergedColumnParallelLinear, AscendMergedColumnParallelLinear, +AscendRowParallelLinear and AscendColumnParallelLinear. 
""" from typing import Optional, Union import torch +import torch.nn as nn from torch.nn.parameter import Parameter -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, - ColumnParallelLinear, - LinearBase, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.distributed import divide +from vllm.model_executor.layers.linear import ( # noqa + WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase, + MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase, + RowParallelLinear, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig from vllm.model_executor.utils import set_weight_attrs -from vllm_ascend.distributed.parallel_state import ( - get_mlp_tensor_model_parallel_rank, - get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group) +from vllm_ascend.ops.linear_op import (get_column_parallel_op, + get_row_parallel_op) -class AscendMlpColumnParallelLinear(ColumnParallelLinear): - """Linear layer with column parallelism. 
+# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group +class AscendLinearBase(LinearBase): + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + ): + nn.Module.__init__(self) + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + self.quant_config = quant_config + self.prefix = prefix + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self, + prefix=prefix) + self.return_bias = return_bias + self.disable_tp = disable_tp + + +class AscendQKVParallelLinear(QKVParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. 
+ """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + ): + self.custom_op, _, tp_size = get_column_parallel_op( + disable_tp, prefix, self) + # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, + self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + AscendColumnParallelLinear.__init__(self, + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp) + + def forward( + self, + input_, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.custom_op is not None: + return self.custom_op.apply(input_) + + return super().forward(input_) + + +class 
AscendMergedColumnParallelLinear(MergedColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. Use the MLP tensor parallelism group in the MLP module, and the original TP group in other modules. @@ -48,73 +162,46 @@ class AscendMlpColumnParallelLinear(ColumnParallelLinear): def __init__( self, input_size: int, - output_size: int, + output_sizes: list[int], bias: bool = True, gather_output: bool = False, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[list[int]] = None, prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): - # Divide the weight matrix along the last dimension. - if prefix.find("gate_up_proj") != -1: - self.tp_size = get_mlp_tensor_model_parallel_world_size() - self.tp_rank = get_mlp_tensor_model_parallel_rank() - self.enable_mlp_optimze = True - else: - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.enable_mlp_optimze = False - self.input_size_per_partition = input_size - self.output_size_per_partition = divide(output_size, self.tp_size) - self.output_partition_sizes = [self.output_size_per_partition] - # If QKV or MergedColumn, use output size of each partition. 
- if hasattr(self, "output_sizes"): - self.output_partition_sizes = [ - divide(output_size, self.tp_size) - for output_size in self.output_sizes - ] - LinearBase.__init__(self, - input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias) + self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op( + disable_tp, prefix, self) + # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group + self.output_sizes = output_sizes + assert all(output_size % self.tp_size == 0 + for output_size in output_sizes) + AscendColumnParallelLinear.__init__(self, + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp) - self.gather_output = gather_output + def forward( + self, + input_, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.custom_op is not None: + return self.custom_op.apply(input_) - if output_sizes is None: - output_sizes = [output_size] - - assert self.quant_method is not None - self.quant_method.create_weights( - layer=self, - input_size_per_partition=self.input_size_per_partition, - output_partition_sizes=self.output_partition_sizes, - input_size=self.input_size, - output_size=self.output_size, - params_dtype=self.params_dtype, - weight_loader=( - self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) - if bias: - self.bias = Parameter( - torch.empty(self.output_size_per_partition, - dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) - else: - self.register_parameter("bias", None) + return super().forward(input_) -class AscendMlpRowParallelLinear(RowParallelLinear): +class 
AscendRowParallelLinear(RowParallelLinear): """Linear layer with row parallelism. Use the MLP tensor parallelism group in the MLP module, and the original TP group in other modules. @@ -133,28 +220,25 @@ class AscendMlpRowParallelLinear(RowParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): - if prefix.find("down_proj") != -1: - self.tp_size = get_mlp_tensor_model_parallel_world_size() - self.tp_rank = get_mlp_tensor_model_parallel_rank() - self.enable_mlp_optimze = True - else: - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.enable_mlp_optimze = False + self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op( + disable_tp, prefix, self) + # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group # Divide the weight matrix along the first dimension. self.input_size_per_partition = divide(input_size, self.tp_size) self.output_size_per_partition = output_size self.output_partition_sizes = [output_size] - LinearBase.__init__(self, - input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias) + AscendLinearBase.__init__(self, + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -184,66 +268,22 @@ class AscendMlpRowParallelLinear(RowParallelLinear): else: self.register_parameter("bias", None) + if self.custom_op is not None: + self.custom_op.update_attrs() + def forward( self, input_, + is_prefill: bool = True, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - if self.enable_mlp_optimze: - tp_rank = get_mlp_tensor_model_parallel_rank() - if self.input_is_parallel: - input_parallel = input_ - else: - tp_rank = get_mlp_tensor_model_parallel_rank() 
- splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() - # Matrix multiply. - assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) - bias_ = None if (self.tp_rank > 0 - or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, - input_parallel, - bias=bias_) - output = get_mlp_tp_group().reduce_scatter(output_parallel, 0) - # output = output[:num_tokens,:] - # dispose_tensor(output_parallel) - else: - if self.input_is_parallel: - input_parallel = input_ - else: - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() + if self.custom_op is not None: + return self.custom_op.apply(input_) - # Matrix multiply. - assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) - bias_ = None if (self.tp_rank > 0 - or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, - input_parallel, - bias=bias_) - if self.reduce_results and self.tp_size > 1: - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - - if not self.return_bias: - return output - return output, output_bias + return super().forward(input_) -class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear): - """Packed linear layers with column parallelism. - - Similar to ColumnParallelLinear, but the weight matrix is concatenated - along the output dimension. When the weight matrix is loaded, the - different partitions are sharded separately. 
+class AscendColumnParallelLinear(ColumnParallelLinear): + """Linear layer with column parallelism. Use the MLP tensor parallelism group in the MLP module, and the original TP group in other modules. @@ -252,58 +292,76 @@ class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear): def __init__( self, input_size: int, - output_sizes: list[int], + output_size: int, bias: bool = True, gather_output: bool = False, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[list[int]] = None, prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): - self.output_sizes = output_sizes - if prefix.find("gate_up_proj") != -1: - self.tp_size = get_mlp_tensor_model_parallel_world_size() - self.tp_rank = get_mlp_tensor_model_parallel_rank() - self.enable_mlp_optimze = True + self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op( + disable_tp, prefix, self) + # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. 
+ if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) + for output_size in self.output_sizes + ] + + AscendLinearBase.__init__(self, + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp) + + self.gather_output = gather_output + + if output_sizes is None: + output_sizes = [output_size] + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) else: - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.enable_mlp_optimze = False - assert all(output_size % self.tp_size == 0 - for output_size in output_sizes) - AscendMlpColumnParallelLinear.__init__(self, - input_size=input_size, - output_size=sum(output_sizes), - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix, - return_bias=return_bias) + self.register_parameter("bias", None) + + if self.custom_op is not None: + self.custom_op.update_attrs() def forward( self, input_, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - bias = self.bias if not self.skip_bias_add else None - # self.global_batch_size = vllm_config.scheduler_config.max_num_seqs - # Matrix multiply. 
- assert self.quant_method is not None - if self.enable_mlp_optimze: - input2_ = get_mlp_tp_group().all_gather(input_, 0) - output = self.quant_method.apply(self, input2_, bias) - else: - output_parallel = self.quant_method.apply(self, input_, bias) - if self.gather_output: - # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) - else: - output = output_parallel + if self.custom_op is not None: + return self.custom_op.apply(input_) - output_bias = self.bias if self.skip_bias_add else None - if not self.return_bias: - return output - return output, output_bias + return super().forward(input_) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py new file mode 100644 index 0000000..819af72 --- /dev/null +++ b/vllm_ascend/ops/linear_op.py @@ -0,0 +1,459 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file extends the functionality of linear operations by encapsulating custom +communication groups and forward functions into classes (linear ops). 
+
+Current class inheritance structure:
+CustomTensorParallelOp
+├── CustomColumnParallelOp
+│   ├── MLPColumnParallelOp
+│   ├── SequenceMergedColumnParallelOp
+│   └── SequenceQKVParallelOp
+└── CustomRowParallelOp
+    ├── MLPRowParallelOp
+    ├── OProjRowParallelOp
+    ├── MatmulAllreduceRowParallelOp
+    └── SequenceRowParallelOp
+
+How to extend a new linear op? Taking column parallel op as an example:
+1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
+2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
+3. Override the apply method according to requirements, which will replace the original linear.forward
+4. Add selection logic for MyColumnParallelOp in the get_column_parallel_op method, typically based on prefix and configuration judgments
+Row parallel op follows a similar approach - inherit from CustomRowParallelOp and register the new class in get_row_parallel_op.
+"""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch_npu
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+from vllm.distributed import split_tensor_along_last_dim
+from vllm.distributed.parallel_state import get_tp_group
+
+from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
+                                                    get_otp_group)
+from vllm_ascend.utils import (dense_optim_enable, enable_sp,
+                               matmul_allreduce_enable, mlp_tp_enable,
+                               oproj_tp_enable)
+
+
+class CustomTensorParallelOp:
+
+    def __init__(self, layer):
+        self.layer = layer
+        self.bias = None
+        self.skip_bias_add = None
+        self.return_bias = None
+        self.quant_method = None
+
+    # Custom communication group, while determining weight sharding
+    @property
+    def comm_group(self):
+        return get_tp_group()
+
+    @property
+    def tp_rank(self):
+        return self.comm_group.rank_in_group
+
+    @property
+    def tp_size(self):
+        return self.comm_group.world_size
+
+    # 
Update the attributes required by apply(), obtaining them from the layer. + # Call this after the layer completes its initialization, specifically at the end of layer.init(). + def update_attrs(self): + if hasattr(self.layer, "bias"): + self.bias = self.layer.bias + self.skip_bias_add = self.layer.skip_bias_add + self.return_bias = self.layer.return_bias + self.quant_method = self.layer.quant_method + self.prefix = self.layer.prefix + + def apply_impl(self, input_): + raise NotImplementedError + + # Replace layer.forward to customize the layer computation process. + def apply(self, input_): + output, output_bias = self.apply_impl(input_) + if not self.return_bias: + return output + return output, output_bias + + +class CustomColumnParallelOp(CustomTensorParallelOp): + + def __init__(self, layer): + super().__init__(layer) + self.gather_output = None + + def update_attrs(self): + super().update_attrs() + self.gather_output = self.layer.gather_output + + +class CustomRowParallelOp(CustomTensorParallelOp): + + def __init__(self, layer): + super().__init__(layer) + self.reduce_results = None + self.input_is_parallel = None + self.input_size_per_partition = None + + def update_attrs(self): + super().update_attrs() + self.input_is_parallel = self.layer.input_is_parallel + self.reduce_results = self.layer.reduce_results + self.input_size_per_partition = self.layer.input_size_per_partition + + def apply(self, input_): + output, output_bias = self.apply_impl(input_) + if dense_optim_enable(): + torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix) + if not self.return_bias: + return output + return output, output_bias + + +class MLPColumnParallelOp(CustomColumnParallelOp): + + def __init__(self, layer): + super().__init__(layer) + + @property + def comm_group(self): + return get_mlp_tp_group() + + def apply_impl( + self, + input_: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add 
else None + # Matrix multiply. + assert self.quant_method is not None + input_parallel = self.comm_group.all_gather(input_, 0) + output = self.quant_method.apply(self.layer, input_parallel, bias) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class SequenceMergedColumnParallelOp(CustomColumnParallelOp): + + def apply_impl( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + """Linear layer with column parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. + """ + + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True) + output_parallel = self.quant_method.apply(self.layer, input_, bias) + + if self.gather_output: + # All-gather across the partitions. + output = self.comm_group.all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class SequenceQKVParallelOp(CustomColumnParallelOp): + + def __init__(self, layer, prefix): + super().__init__(layer) + self.prefix = prefix + + def apply_impl( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + """Linear layer with column parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. + """ + + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + + layer_num = self.prefix.split('.')[2] + + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + input_, layer_num != '0') + output_parallel = self.quant_method.apply(self.layer, input_, bias) + + if self.gather_output: + # All-gather across the partitions. 
+ output = self.comm_group.all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class MLPRowParallelOp(CustomRowParallelOp): + + def __init__(self, layer): + super().__init__(layer) + + @property + def comm_group(self): + return get_mlp_tp_group() + + def apply_impl( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + assert self.quant_method is not None + bias_ = None if (self.tp_rank > 0 + or self.skip_bias_add) else self.layer.bias + output_parallel = self.quant_method.apply(self.layer, + input_parallel, + bias=bias_) + output = self.comm_group.reduce_scatter(output_parallel, 0) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class OProjRowParallelOp(CustomRowParallelOp): + + def __init__(self, layer): + super().__init__(layer) + + @property + def comm_group(self): + return get_otp_group() + + def apply_impl( + self, + input_: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + + if self.input_is_parallel: + input_parallel = input_ + else: + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Prepare tensors for all-to-all communication + local_batch_size = input_parallel.size(0) + chunk_size = self.input_size_per_partition + total_batch_size = local_batch_size * self.tp_size + + # Reshape tensor for efficient cross-device transfer: + # [batch, dim] -> [tp_size, batch, chunk] -> flattened + send_buf = (input_parallel.reshape(-1, + self.tp_size, chunk_size).transpose( + 0, 1).contiguous().view(-1)) + + # Create receive buffer + 
recv_buf = torch.empty(total_batch_size * chunk_size, + dtype=input_parallel.dtype, + device=input_parallel.device) + + # Perform all-to-all communication + dist.all_to_all_single(recv_buf, + send_buf, + group=self.comm_group.device_group) + input_parallel = recv_buf.view(total_batch_size, chunk_size) + + # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1 + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self.layer, + input_parallel, + bias=bias_) + + # otp-specific: Combine partial results across devices + output = self.comm_group.reduce_scatter(output_parallel, dim=0) + output = output.view(input_.shape[0], self.layer.output_size) + + # Handle bias return based on configuration + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def update_attrs(self): + super().update_attrs() + self.input_is_parallel = self.layer.input_is_parallel + self.input_size_per_partition = self.layer.input_size_per_partition + + +class MatmulAllreduceRowParallelOp(CustomRowParallelOp): + _HCOMM_INFO = None + + def __init__(self, layer): + super().__init__(layer) + self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group) + + def apply_impl( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + """Calculate the output tensor of forward by considering + fusing communication and computation.""" + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + if self.reduce_results and self.tp_size > 1: + output = torch_npu.npu_mm_all_reduce_base(input_parallel, + self.weight_t, + self.hcomm_info, + bias=bias_) + else: + assert self.quant_method is not 
None
+            output = self.quant_method.apply(self.layer,
+                                             input_parallel,
+                                             bias=bias_)
+
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+    @classmethod
+    def get_hcomm_info(cls, group: ProcessGroup) -> str:
+        """Get the HCCL communication information for the given group."""
+        if cls._HCOMM_INFO is not None:
+            return cls._HCOMM_INFO
+
+        rank = torch.distributed.get_rank(group)
+        if torch.__version__ > "2.0":
+            global_rank = torch.distributed.get_global_rank(group, rank)
+            cls._HCOMM_INFO = group._get_backend(
+                torch.device("npu")).get_hccl_comm_name(global_rank)
+        else:
+            cls._HCOMM_INFO = group.get_hccl_comm_name(rank)
+        return cls._HCOMM_INFO
+
+    def update_attrs(self):
+        super().update_attrs()
+        self.weight_t = self.layer.weight.t()
+
+
+class SequenceRowParallelOp(CustomRowParallelOp):
+
+    def __init__(self, layer, prefix):
+        super().__init__(layer)
+        self.prefix = prefix
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with row parallelism.
+
+        Implemented multiple optimization projects for dense models, such as FlashComm and
+        communication-computation fusion. 
+ """ + + if self.input_is_parallel: + input_parallel = input_ + else: + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + assert self.quant_method is not None + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + + if self.tp_size == 1 or not self.reduce_results: + output = self.quant_method.apply(self.layer, + input_parallel, + bias=bias_) + else: + output_parallel = self.quant_method.apply(self.layer, + input_parallel, + bias=bias_) + output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def update_attrs(self): + super().update_attrs() + self.input_is_parallel = self.layer.input_is_parallel + self.reduce_results = self.layer.reduce_results + + +def get_column_parallel_op( + disable_tp, prefix, layer +) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp, + SequenceQKVParallelOp]], int, int]: + if disable_tp: + return None, 0, 1 + + custom_op: Optional[Union[ + MLPColumnParallelOp, + SequenceMergedColumnParallelOp, + SequenceQKVParallelOp, + ]] = None + if "gate_up_proj" in prefix and mlp_tp_enable(): + custom_op = MLPColumnParallelOp(layer) + elif "gate_up_proj" in prefix and enable_sp(): + custom_op = SequenceMergedColumnParallelOp(layer) + elif enable_sp(): + custom_op = SequenceQKVParallelOp(layer, prefix) + + if custom_op is not None: + return custom_op, custom_op.tp_rank, custom_op.tp_size + + return None, get_tp_group().rank_in_group, get_tp_group().world_size + + +def get_row_parallel_op( + disable_tp, prefix, layer +) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp, + MatmulAllreduceRowParallelOp, + SequenceRowParallelOp]], int, int]: + if disable_tp: + return None, 0, 1 + + custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp, + MatmulAllreduceRowParallelOp, + SequenceRowParallelOp]] = 
None + if "down_proj" in prefix and mlp_tp_enable(): + custom_op = MLPRowParallelOp(layer) + elif "o_proj" in prefix and oproj_tp_enable(): + custom_op = OProjRowParallelOp(layer) + elif matmul_allreduce_enable(): + custom_op = MatmulAllreduceRowParallelOp(layer) + elif enable_sp(): + custom_op = SequenceRowParallelOp(layer, prefix) + + if custom_op is not None: + return custom_op, custom_op.tp_rank, custom_op.tp_size + + return None, get_tp_group().rank_in_group, get_tp_group().world_size diff --git a/vllm_ascend/ops/moe/__init__.py b/vllm_ascend/ops/moe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/ops/comm_utils.py b/vllm_ascend/ops/moe/comm_utils.py similarity index 55% rename from vllm_ascend/ops/comm_utils.py rename to vllm_ascend/ops/moe/comm_utils.py index e893049..b8952a9 100644 --- a/vllm_ascend/ops/comm_utils.py +++ b/vllm_ascend/ops/moe/comm_utils.py @@ -1,5 +1,7 @@ +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # Copyright 2023 The vLLM team. +# This file is a part of the vllm-ascend project. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# This file is a part of the vllm-ascend project. +# import torch import torch.distributed import torch.distributed as dist @@ -60,3 +62,52 @@ def async_all_to_all(input_, group=group, async_op=True) return input_, a2a_out, handle + + +def _gather_along_first_dim(input_, group, output_split_sizes=None): + """Gather tensors and concatenate along the first dimension. + + Args: + input_tensor (torch.Tensor): + A tensor to be gathered. 
+ output_split_sizes (List[int], optional): + A list specifying the sizes of the output splits along the first dimension. + If None, equal splitting is assumed. Default: None. + + Returns: + torch.Tensor: Gathered tensor. + """ + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + if output_split_sizes is None: + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.all_gather_into_tensor(output, + input_.contiguous(), + group=group) + else: + dim_size[0] = sum(output_split_sizes) + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + output_tensor_list = list( + torch.split(output, output_split_sizes, dim=0)) + torch.distributed.all_gather(output_tensor_list, input_, group=group) + + return output + + +def gather_from_sequence_parallel_region( + input_, + group, + output_split_sizes=None, +): + """Wrapper for autograd function: forward: AG, backward: RS """ + return _gather_along_first_dim(input_, group, output_split_sizes) \ No newline at end of file diff --git a/vllm_ascend/ops/layers/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py similarity index 100% rename from vllm_ascend/ops/layers/experts_selector.py rename to vllm_ascend/ops/moe/experts_selector.py diff --git a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py new file mode 100644 index 0000000..3d800e4 --- /dev/null +++ b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py @@ -0,0 +1,459 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +from abc import ABC, abstractmethod + +import torch +import torch.distributed as dist +import torch.nn as nn +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_dp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + +from vllm_ascend.utils import vllm_version_is + + +class FusedMoEPrepareAndFinalize(ABC): + """ + Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization + in distributed environments. Subclasses implement specific communication strategies + (e.g., AllGather, All2All, MC2, Naive Multicast) to handle tensor padding, slicing, + broadcasting, and reduction across TP/DP/EP groups. + + Attributes: + moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info, + sizes, ranks, and communication settings. + """ + + def __init__(self, moe_config: FusedMoEConfig): + self.moe_config = moe_config + + @abstractmethod + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + rm_router_logits: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepare tensors before MoE computation. 
May involve: + - Padding to align communication boundaries + - Slicing across tensor-parallel ranks + - Broadcasting across data-parallel ranks + - Recomputing router logits if needed + + Args: + hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size] + router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts] + enable_shared_expert_dp (bool): Skip DP communication for shared experts + rm_router_logits (bool): Discard input router_logits and recompute via gate + replace_allreduce (bool): Bypass default all-reduce behavior + gate (nn.Module, optional): Gate network to recompute router_logits if needed + + Returns: + Tuple of: + - processed hidden_states (may be padded/sliced/broadcasted) + - processed router_logits (may be recomputed or broadcasted) + - optional communication mask (e.g., mc2_mask for sparse ops) + """ + raise NotImplementedError("Prepare not implemented.") + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalize MoE output. May involve: + - Gathering sliced tensors across TP ranks + - Reducing or scattering across DP ranks + - Unpadding to original token count + - Applying all-reduce across TP/EP if requested + + Args: + hidden_states (torch.Tensor): MoE layer output, possibly padded or sliced + reduce_results (bool): Whether to apply all-reduce across TP/EP groups + + Returns: + torch.Tensor: Final output with shape [original_num_tokens, hidden_size] + """ + raise NotImplementedError("Finalize function not implemented.") + + +class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using MC2 (Memory-Centric Communication). + Designed for Ascend or environments requiring explicit padding and slicing control. + Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment. 
+ """ + + def __init__(self, moe_config: FusedMoEConfig): + super().__init__(moe_config) + self._restore_tp_across_dp() + + def _restore_tp_across_dp(self): + """ + Restore original TP configuration. + vLLM flattens TP and DP into a single dimension; this method recovers + the true TP world size and rank for correct tensor slicing. + """ + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + rm_router_logits: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preparation steps: + 1. Fetch `mc2_mask` and target padding length from forward context. + 2. Pad `hidden_states` and `router_logits` to target length if needed. + 3. If TP > 1, split tensors along token dimension and select current TP rank's slice. + 4. Split and return corresponding `mc2_mask`. + + Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True. + + Returns: + Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded. 
+ """ + self.replace_allreduce = replace_allreduce + self.enable_shared_expert_dp = enable_shared_expert_dp + forward_context = get_forward_context() + mc2_mask = forward_context.mc2_mask + if self.tp_size > 1: + # Also slice mc2_mask + split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0) + mc2_mask = split_mc2_mask[self.tp_rank] + + if not self.replace_allreduce: + self.num_tokens, _ = hidden_states.shape + target_pad_length = forward_context.padded_num_tokens + pad_size = target_pad_length - self.num_tokens + + # Pad if necessary (unless shared expert DP is enabled) + if pad_size > 0 and not self.enable_shared_expert_dp: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + # Slice across TP ranks + if self.tp_size > 1 and not self.enable_shared_expert_dp: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + self.split_hidden_states = split_hidden_states # Save for finalize + + return hidden_states, router_logits, mc2_mask + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalization steps: + 1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor. + 2. Unpad to original token count if padding was applied. + 3. Return tensor with shape [original_num_tokens, hidden_size]. + + Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True. 
+ """ + if not (self.enable_shared_expert_dp or self.replace_allreduce): + if self.tp_size > 1: + # All-gather across TP group + dist.all_gather(list(self.split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(self.split_hidden_states, dim=0) + + # TODO: It is a quick bugfix for the memory explosion issue in eager mode. + # If the cache is not cleared after `self.split_hidden_states` is created, + # it can lead to the memory explosion in eager mode. + del self.split_hidden_states + + # Unpad if necessary + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] + + return hidden_states + + +class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using All-to-All style slicing. + Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing. + Will be used when num_tokens exceed mc2's limitation (512 tokens/rank). + """ + + def __init__(self, moe_config: FusedMoEConfig): + super().__init__(moe_config) + self._restore_tp_across_dp() + + def _restore_tp_across_dp(self): + """Restore original TP configuration (same as MC2).""" + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + rm_router_logits: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preparation steps: + 1. Pad hidden_states and router_logits to next multiple of TP size. + 2. If TP > 1, split along token dim and select current TP rank's slice. + 3. Save splits for later all-gather in finalize. + + Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. + + Returns: + Tuple of (hidden_states, router_logits, None) — no mask used in All2All. 
+ """ + self.replace_allreduce = replace_allreduce + self.enable_shared_expert_dp = enable_shared_expert_dp + + if not (self.replace_allreduce or self.enable_shared_expert_dp): + self.num_tokens, _ = hidden_states.shape + pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic) + + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + if self.tp_size > 1: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + self.split_hidden_states = split_hidden_states + + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + + return hidden_states, router_logits, None + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalization steps: + 1. If TP > 1, all-gather slices to reconstruct full tensor. + 2. Unpad to original token count. + 3. Return [original_num_tokens, hidden_size] tensor. + + Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. + """ + if not (self.enable_shared_expert_dp or self.replace_allreduce): + if self.tp_size > 1: + dist.all_gather(list(self.split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(self.split_hidden_states, dim=0) + + # TODO: It is a quick bugfix for the memory explosion issue in eager mode. + # If the cache is not cleared after `self.split_hidden_states` is created, + # it can lead to the memory explosion in eager mode. + del self.split_hidden_states + + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] + + return hidden_states + + +class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using All-Gather + Reduce-Scatter. 
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
    """
    MoE communication strategy using All-Gather + Reduce-Scatter.
    Designed for DP > 1: gather inputs across DP ranks before MoE, scatter outputs after.
    Uses `max_tokens_across_dp` from forward_context for padding alignment.
    """

    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                rm_router_logits: bool = False,
                replace_allreduce: bool = False,
                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preparation steps:
        1. Fetch max token count across DP group from forward context.
        2. Pad local tensors to that size.
        3. All-gather across DP group to form global input tensor.
        4. Optionally recompute router_logits using gate if `rm_router_logits=True`.

        NOTE(review): when `rm_router_logits=True` a non-None `gate` is
        required here — there is no fallback; confirm callers guarantee it.

        Returns:
            Tuple of (global_hidden_states, global_router_logits, None)
        """
        self.enable_shared_expert_dp = enable_shared_expert_dp

        if self.moe_config.dp_size > 1:
            forward_context = get_forward_context()
            max_tokens_across_dp = forward_context.max_tokens_across_dp

            # Remember the local (unpadded) token count for finalize().
            self.num_tokens = hidden_states.shape[0]
            pad_size = max_tokens_across_dp - self.num_tokens
            if pad_size > 0:
                hidden_states = nn.functional.pad(hidden_states,
                                                  (0, 0, 0, pad_size))
                if not rm_router_logits:
                    router_logits = nn.functional.pad(router_logits,
                                                      (0, 0, 0, pad_size))

            # All-gather across DP group
            hidden_states = self.moe_config.dp_group.all_gather(
                hidden_states, 0)
            if rm_router_logits:
                router_logits, _ = gate(hidden_states)  # Recompute globally
            else:
                router_logits = self.moe_config.dp_group.all_gather(
                    router_logits, 0)

        return hidden_states, router_logits, None

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Finalization steps:
        1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
        2. Slice to original local token count.
        3. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.

        Returns:
            Tensor with shape [original_local_num_tokens, hidden_size]
        """
        if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
            hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
            # Drop the padding added in prepare().
            hidden_states = hidden_states[:self.num_tokens]

        if reduce_results and (self.moe_config.tp_size > 1
                               or self.moe_config.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        return hidden_states
class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
    """
    MoE communication strategy using Naive Multicast (point-to-point broadcast).
    Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others.
    Uses `cu_tokens_across_dp_cpu` (cumulative tokens) to locate slice boundaries.
    """

    def _naive_multicast(self, x: torch.Tensor,
                         cu_tokens_across_dp_cpu: torch.Tensor):
        """
        Naive multicast implementation:
        1. Create global buffer sized by total tokens across DP.
        2. Current rank copies its slice into its designated buffer region.
        3. Each rank broadcasts its slice to all others via P2P.

        Args:
            x (torch.Tensor): Local tensor [local_tokens, hidden_size]
            cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts per DP rank

        Returns:
            torch.Tensor: Global tensor [total_tokens, hidden_size]
        """
        assert len(x.shape) == 2, "Input must be 2D [tokens, features]"
        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
                             device=x.device,
                             dtype=x.dtype)

        # Copy local slice into buffer
        start = 0 if self.moe_config.dp_rank == 0 else cu_tokens_across_dp_cpu[
            self.moe_config.dp_rank - 1]
        end = cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
        buffer[start:end, :].copy_(x)

        # Broadcast each slice to all ranks; every rank participates in each
        # broadcast, so after the loop the whole buffer is populated everywhere.
        for idx in range(self.moe_config.dp_size):
            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
            end = cu_tokens_across_dp_cpu[idx]
            get_dp_group().broadcast(buffer[start:end, :], idx)
        return buffer

    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                rm_router_logits: bool = False,
                replace_allreduce: bool = False,
                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preparation steps:
        1. Fetch cumulative token boundaries from forward context.
        2. Multicast hidden_states and router_logits to form global tensors.
        3. Optionally recompute router_logits globally if `rm_router_logits=True`.

        NOTE(review): as in the AllGather strategy, `gate` must be provided
        when `rm_router_logits=True`.

        Returns:
            Tuple of (global_hidden_states, global_router_logits, None)
        """
        self.enable_shared_expert_dp = enable_shared_expert_dp

        if self.moe_config.dp_size > 1:
            # The attribute name of the cumulative-token metadata changed
            # after vLLM 0.10.2; pick the right accessor per version.
            if vllm_version_is("0.10.2"):
                self.cu_tokens_across_dp_cpu = get_forward_context(
                ).dp_metadata.cu_tokens_across_dp_cpu
            else:
                self.cu_tokens_across_dp_cpu = get_forward_context(
                ).dp_metadata.cu_tokens_across_sp(1)
            hidden_states = self._naive_multicast(hidden_states,
                                                  self.cu_tokens_across_dp_cpu)
            if rm_router_logits:
                router_logits, _ = gate(hidden_states)
            else:
                router_logits = self._naive_multicast(
                    router_logits, self.cu_tokens_across_dp_cpu)

        return hidden_states, router_logits, None

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Finalization steps:
        1. If DP > 1 and not shared expert:
           - All-reduce across DP
           - Slice to current rank's token range using cu_tokens_across_dp_cpu
        2. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.

        Returns:
            Tensor with shape [local_num_tokens, hidden_size]
        """
        if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
            start = 0 if self.moe_config.dp_rank == 0 else self.cu_tokens_across_dp_cpu[
                self.moe_config.dp_rank - 1]
            end = self.cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
            hidden_states = get_dp_group().all_reduce(
                hidden_states)  # Sum across DP
            hidden_states = hidden_states[start:end, :]

        if reduce_results and (self.moe_config.tp_size > 1
                               or self.moe_config.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        return hidden_states
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +import torch +from vllm.config import get_current_vllm_config +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + +from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( + FusedMoEPrepareAndFinalizeWithAll2All, + FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, + FusedMoEPrepareAndFinalizeWithNaiveMulticast) +from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp +from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV, + TokenDispatcherWithAllGather, + TokenDispatcherWithMC2, + TokenDispatcherWithMoge) + +_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {} + + +def get_moe_comm_method( + moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]: + return _MoECommMethods.get(moe_comm_type) + + +def setup_moe_comm_method(moe_config): + _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config) + _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config) + _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config) + _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl( + moe_config) + + +class MoECommMethod(ABC): 
+ """Base class for MoE communication methods.""" + + def __init__(self, moe_config: FusedMoEConfig): + self.model_type = get_current_vllm_config( + ).model_config.hf_config.model_type + self.moe_config = moe_config + self.mc2_mask = None + + self.token_dispatcher = self._get_token_dispatcher() + self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize( + ) + + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + rm_router_logits: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare( + hidden_states, router_logits, enable_shared_expert_dp, + rm_router_logits, replace_allreduce, gate) + self.mc2_mask = mc2_mask + return hidden_states, router_logits + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + hidden_states = self.fused_moe_prepare_finalize.finalize( + hidden_states, reduce_results) + return hidden_states + + def fused_experts( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + row_idx: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_int8_w8a8: bool = False, + use_int4_w4a8: bool = False, + global_num_experts: Optional[int] = None, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + # For TorchAir graph + is_torchair: bool = False, + # For Cube/Vector parallel + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + # For load balance + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + need_trans: bool = False, + 
dynamic_eplb: bool = False): + # Check constraints + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + moe_comm_method = get_forward_context().moe_comm_method + assert moe_comm_method is not None, "Missing communication context" + + results = self.token_dispatcher.token_dispatch( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + expert_map=expert_map, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, + mc2_mask=self.mc2_mask, + apply_router_weight_on_input=apply_router_weight_on_input, + with_quant=use_int8_w8a8 or use_int4_w4a8) + + permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \ + results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales") + + mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=expert_tokens, + dynamic_scale=dynamic_scale, + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + topk_scales=topk_scales, + with_quant=use_int8_w8a8 + or use_int4_w4a8, + fusion=use_int8_w8a8, + need_trans=need_trans) + + final_hidden_states = self.token_dispatcher.token_combine( + hidden_states=mlp_output) + + if dynamic_eplb: + return (final_hidden_states, group_list_type, expert_tokens) + + return final_hidden_states + + @abstractmethod + def _get_token_dispatcher(self): + raise NotImplementedError( + "_get_token_dispatcher function not implemented.") + + @abstractmethod + def _get_fused_moe_prepare_finalize(self): + raise NotImplementedError( + "_get_fused_moe_prepare_finalize function not implemented.") + + +class AllGatherCommImpl(MoECommMethod): + """This implementation is the 
same as NativeAllGatherCommImpl, + but uses NPU-specific ops for better performance. + + This implementation should be compatible with all scenarios, and + thus it is the default implementation for MoE communication methods. + It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing + and `torch_npu.npu_moe_token_unpermute` for post-processing + to handle the token-to-expert mapping and communication efficiently. + + NOTE(Yizhou): TBH, it is really weird that we were supposed to use + `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing` + or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute` + for pre-processing and post-processing, respectively. + But `npu_moe_finalize_routing` will lead to accuracy issues so we have to + use `torch_npu.npu_moe_token_unpermute` instead. + This is a workaround and should be removed after the issue is fixed. + """ + + def _get_token_dispatcher(self): + if self.model_type == "PanguProMoE": + return TokenDispatcherWithMoge( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + else: + return TokenDispatcherWithAllGather( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config) + + +class MC2CommImpl(MoECommMethod): + """This implementation is for the scenarios listed below: + 1. `enable_expert_parallel=True`. + 2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available. + 3. `enable_expert_parallel=False` is not supported. + + This implementation uses the MC2 communication method, which is optimized for + Communication and Computation parallelism on Ascend devices. 
+ """ + + def _get_token_dispatcher(self): + return TokenDispatcherWithMC2() + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config) + + +class AlltoAllCommImpl(MoECommMethod): + """This implementation is for the scenarios listed below: + 1. `enable_expert_parallel=True`. + 2. `npu_grouped_matmul` is available. + + This implementation uses all-to-all communication to exchange tokens + between data parallel ranks before and after the MLP computation. It should + have better performance than AllGatherCommImpl when DP size > 1. + """ + + def _get_token_dispatcher(self): + return TokenDispatcherWithAll2AllV( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) + + +class NaiveMulticastCommImpl(MoECommMethod): + """This implementation is the same as NativeAllGatherCommImpl, + but uses NPU-specific ops for better performance. + + This implementation should be compatible with all scenarios, and + thus it is the default implementation for MoE communication methods. + It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing + and `torch_npu.npu_moe_token_unpermute` for post-processing + to handle the token-to-expert mapping and communication efficiently. + + NOTE(Yizhou): TBH, it is really weird that we were supposed to use + `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing` + or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute` + for pre-processing and post-processing, respectively. + But `npu_moe_finalize_routing` will lead to accuracy issues so we have to + use `torch_npu.npu_moe_token_unpermute` instead. + This is a workaround and should be removed after the issue is fixed. 
+ """ + + def _get_token_dispatcher(self): + return TokenDispatcherWithAllGather( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config) diff --git a/vllm_ascend/ops/layers/moe_mlp.py b/vllm_ascend/ops/moe/moe_mlp.py similarity index 51% rename from vllm_ascend/ops/layers/moe_mlp.py rename to vllm_ascend/ops/moe/moe_mlp.py index c73e8ea..b74f945 100644 --- a/vllm_ascend/ops/layers/moe_mlp.py +++ b/vllm_ascend/ops/moe/moe_mlp.py @@ -18,22 +18,52 @@ from typing import Optional import torch import torch_npu +from torch.nn.functional import pad from vllm.forward_context import get_forward_context -from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.utils import dispose_tensor, is_310p +def cumsum_group_list(group_list: torch.Tensor, + group_list_type: int, + active_num: int = 0, + expert_num: int = 0) -> torch.Tensor: + if group_list_type not in [0, 1, 2]: + raise ValueError( + f"group_list_type should be in [0, 1, 2], but received {group_list_type}" + ) + + if group_list_type == 0: + return group_list + if group_list_type == 1: + return group_list.cumsum(dim=0) + + experts = pad(group_list[:, 0], (1, 0)) + tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0)) + cumsum_group_list = torch.full(size=(expert_num, ), + fill_value=active_num, + dtype=group_list.dtype, + device=group_list.device) + + for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])): + if end > start: + cumsum_group_list[start:end] = tokens[i] + + return cumsum_group_list + + def quant_apply_mlp(hidden_states: torch.Tensor, w1: torch.Tensor, w1_scale: torch.Tensor, w2: torch.Tensor, w2_scale: torch.Tensor, group_list: torch.Tensor, - dynamic_scale: torch.Tensor = None, group_list_type: int = 1, + 
dynamic_scale: torch.Tensor = None, w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None) -> torch.Tensor: + w2_scale_bias: torch.Tensor = None, + fusion: bool = False) -> torch.Tensor: if dynamic_scale is None: unquantized_hidden_states = hidden_states hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( @@ -47,33 +77,40 @@ def quant_apply_mlp(hidden_states: torch.Tensor, bias1, bias2 = None, None _output_dtype = w2_scale.dtype - is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2 + is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 if w1_scale_bias is None and is_mc2: - w1_scale = w1_scale.to(torch.float32) - - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=3, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=torch.int32)[0] - - # act_fn: swiglu - hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( - x=hidden_states, - weight_scale=w1_scale, - activation_scale=pertoken_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=group_list, - activate_left=True, - quant_mode=1, - ) - + if fusion: + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( + x=hidden_states, + weight=w1, + group_list=cumsum_group_list(group_list, group_list_type), + weight_scale=w1_scale, + x_scale=pertoken_scale) + else: + if w1_scale.dtype != torch.float32: + w1_scale = w1_scale.to(torch.float32) + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32)[0] + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale, + activation_scale=pertoken_scale, + bias=None, + quant_scale=None, 
+ quant_offset=None, + group_index=group_list, + activate_left=True, + quant_mode=1, + ) # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], @@ -92,29 +129,37 @@ def quant_apply_mlp(hidden_states: torch.Tensor, [group_list[:1], torch.diff(group_list, dim=0)]) group_list_type = 1 - bias1 = [w1_scale_bias] + bias1 = [w1_scale_bias] if not fusion else w1_scale_bias bias2 = [w2_scale_bias] # TODO w4a8 scene: dynamic acquisition of dtype in the future _output_dtype = torch.bfloat16 - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - scale=[w1_scale], - bias=bias1, - per_token_scale=[pertoken_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=_output_dtype)[0] - - # act_fn: swiglu - hidden_states = torch_npu.npu_swiglu(hidden_states) - hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( - hidden_states) - + if fusion: + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( + x=hidden_states, + weight=w1, + bias=bias1, + group_list=cumsum_group_list(group_list, group_list_type), + weight_scale=w1_scale, + x_scale=pertoken_scale) + else: + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + scale=[w1_scale.to(w2_scale.dtype)], + bias=bias1, + per_token_scale=[pertoken_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( + hidden_states) # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], @@ -127,17 +172,22 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_type=0, group_list=group_list, output_dtype=_output_dtype)[0] + return hidden_states 
-def unquant_apply_mlp( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - group_list: torch.Tensor, - group_list_type: int = 1, - topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor: - w1 = w1.transpose(1, 2) +def unquant_apply_mlp(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, + topk_scales: Optional[torch.Tensor] = None, + need_trans: bool = True) -> torch.Tensor: + + if need_trans: + w1 = w1.transpose(1, 2) + w2 = w2.transpose(1, 2) + gate_up_out = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w1], @@ -155,7 +205,6 @@ def unquant_apply_mlp( if topk_scales is not None: gate_up_out *= topk_scales - w2 = w2.transpose(1, 2) hidden_states = torch_npu.npu_grouped_matmul( x=[gate_up_out], weight=[w2], @@ -178,7 +227,9 @@ def unified_apply_mlp(hidden_states: torch.Tensor, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, topk_scales: Optional[torch.Tensor] = None, - with_quant: bool = False) -> torch.Tensor: + with_quant: bool = False, + fusion: bool = False, + need_trans: bool = True) -> torch.Tensor: if with_quant: return quant_apply_mlp(hidden_states=hidden_states, w1=w1, @@ -189,11 +240,13 @@ def unified_apply_mlp(hidden_states: torch.Tensor, dynamic_scale=dynamic_scale, group_list_type=group_list_type, w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias) + w2_scale_bias=w2_scale_bias, + fusion=fusion) else: return unquant_apply_mlp(hidden_states=hidden_states, w1=w1, w2=w2, group_list=group_list, group_list_type=group_list_type, - topk_scales=topk_scales) + topk_scales=topk_scales, + need_trans=need_trans) diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py similarity index 75% rename from vllm_ascend/ops/moe_dispatcher/token_dispatcher.py rename to vllm_ascend/ops/moe/token_dispatcher.py index 855faad..b36cc44 100644 --- 
a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -22,42 +22,17 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Optional import torch import torch_npu from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.distributed.tensor_parallel import \ - gather_from_sequence_parallel_region -from vllm_ascend.ops.comm_utils import async_all_to_all +from vllm_ascend.ops.moe.comm_utils import ( + async_all_to_all, gather_from_sequence_parallel_region) from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version -_Dispatchers: Dict[str, Any] = {} - - -def _register_token_dispatcher(dispatcher: Any): - _Dispatchers[dispatcher.__class__.__name__] = dispatcher - - -def get_token_dispatcher(name: str): - return _Dispatchers.get(name) - - -def setup_token_dispatchers(ep_size: int, **kwargs): - existing_dispatchers = set(_Dispatchers.keys()) - - if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers: - _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs)) - elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers: - _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs)) - elif ep_size >= 16: - if "TokenDispatcherWithAll2AllV" not in existing_dispatchers: - _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs)) - if "TokenDispatcherWithMC2" not in existing_dispatchers: - _register_token_dispatcher(TokenDispatcherWithMC2(**kwargs)) - class MoETokenDispatcher(ABC): @@ -90,9 +65,9 @@ class MoETokenDispatcher(ABC): expert_map: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = 
None, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): @@ -158,6 +133,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": 0, + "expert_token_nums_type": 0, } stage1_kwargs = { @@ -189,9 +165,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): expert_map: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): @@ -215,6 +191,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): if self.with_quant: if shared_experts is not None: + share_up_out, _ = shared_experts.gate_up_proj( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_gate_up, shared_dequant_scale = share_up_out[ + 0], share_up_out[1] + shared_act_out = shared_experts.act_fn( (shared_gate_up, shared_dequant_scale)) self.shared_act, self.swiglu_out_scale = \ @@ -224,7 +205,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): if shared_experts is not None: shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) self.shared_act = shared_experts.act_fn(shared_gate_up) - group_list_type = 1 + group_list_type = 0 return { "group_list_type": group_list_type, "hidden_states": expand_x, @@ -291,6 +272,16 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): **kwargs_mc2 ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( **kwargs_mc2) + + # these values are 
no longer used, so they need to be set to None for memory release. + self.output = None + self.assist_info_for_combine = None + self.ep_recv_counts = None + self.topk_ids = None + self.topk_weights = None + self.mc2_mask = None + self.expert_map = None + if self.shared_experts is None: return hidden_states else: @@ -300,6 +291,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): else: shared_hidden_states, _ = self.shared_experts.down_proj( self.shared_act) + self.shared_act = None + self.shared_experts = None + self.swiglu_out_scale = None return hidden_states, shared_hidden_states @@ -328,9 +322,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): expert_map: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): @@ -338,8 +332,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): self.original_shape = hidden_states.shape num_tokens = hidden_states.shape[:-1].numel() - dtype = hidden_states.dtype - device = hidden_states.device self.expert_map = expert_map self.topk_weights = topk_weights self.topk_ids = topk_ids @@ -353,144 +345,65 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): ), "Only support topk=1 when `apply_router_weight_on_input` is True" hidden_states = hidden_states * \ topk_weights.to(hidden_states.dtype) - if expert_map is not None: - # Generate token indices and flatten - token_indices = (torch.arange( - num_tokens, device=device, - dtype=torch.int64).unsqueeze(1).expand(-1, - self.top_k).reshape(-1)) - - # Flatten token-to-expert mappings and map to local experts - 
weights_flat = topk_weights.view(-1) - experts_flat = topk_ids.view(-1) - local_experts_flat = expert_map[experts_flat] - - # Filter valid token-expert pairs - self.mask = local_experts_flat != -1 - filtered_weights = torch.where( - self.mask, weights_flat, - torch.zeros_like(weights_flat)).to(dtype) - filtered_experts = torch.where( - self.mask, local_experts_flat, - torch.full_like(local_experts_flat, - self.num_experts_local)).to(topk_ids.dtype) - - # Sort by local expert IDs - sort_indices = torch.argsort(filtered_experts.view(torch.float32)) - self.sorted_token_indices = token_indices[sort_indices] - self.sorted_weights = filtered_weights[sort_indices] - - # Compute token counts with minlength of num_experts - # This is equivalent to but faster than: - # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(self.num_experts_local + 1, - device=device, - dtype=torch.int64) - ones = torch.ones_like(filtered_experts, dtype=torch.int64) - token_counts.scatter_add_(0, filtered_experts.to(torch.int64), - ones) - token_counts = token_counts[:self.num_experts_local] - - # Rearrange hidden_states - sorted_hidden_states = hidden_states[self.sorted_token_indices] - if self.with_quant: - group_list_type = 1 - expert_tokens = token_counts - else: - expert_tokens = torch.cumsum(token_counts, - dim=0, - dtype=torch.int64) - group_list_type = 0 + global_num_experts = len(expert_map) + mask = (expert_map[topk_ids] != -1) + self.topk_weights = topk_weights * mask + first_expert_idx = get_ep_group( + ).rank_in_group * self.num_experts_local + last_expert_idx = first_expert_idx + self.num_experts_local else: - active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens - sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=active_num) + first_expert_idx = 0 + last_expert_idx = 
self.num_experts_local + global_num_experts = self.num_experts_local - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, self.num_experts_local) - expert_tokens = expert_tokens.to(torch.int64) - group_list_type = 0 + sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = ( + torch_npu.npu_moe_init_routing_v2( + hidden_states, + topk_ids, + active_num=num_tokens * self.top_k, + expert_num=global_num_experts, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + active_expert_range=[first_expert_idx, last_expert_idx], + quant_mode=1 if self.with_quant else -1, + )) + expert_tokens = expert_tokens.to(torch.int64) + group_list_type = 1 # `count` mode return { "group_list_type": group_list_type, "hidden_states": sorted_hidden_states, "group_list": expert_tokens, + "dynamic_scale": pertoken_scale if self.with_quant else None, } def token_combine(self, hidden_states: torch.Tensor, bias: torch.Tensor = None): assert self.original_shape is not None - dtype = hidden_states.dtype - device = hidden_states.device - if self.expert_map is not None: - assert self.mask is not None - assert self.sorted_token_indices is not None - assert self.sorted_weights is not None + final_hidden_states = torch_npu.npu_moe_token_unpermute( + permuted_tokens=hidden_states, + sorted_indices=self.expanded_row_idx, + probs=self.topk_weights) + if len(self.original_shape) == 3: + final_hidden_states = final_hidden_states.view(self.original_shape) - weighted_down_out = hidden_states * \ - self.sorted_weights.unsqueeze(1) - - final_hidden_states = torch.zeros(*self.original_shape, - device=hidden_states.device, - dtype=hidden_states.dtype) - - # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...] 
- # This created multiple NaN and index_add_ will mix them up which harms accuracy - # remove this mask and filter after it being fixed - num_valid_tokens = self.mask.sum() - valid_token_mask = torch.arange( - 0, self.sorted_token_indices.shape[0], - device=device).unsqueeze(1) < num_valid_tokens - valid_output = torch.where( - valid_token_mask, weighted_down_out, - torch.zeros_like(weighted_down_out)).to(dtype) - final_hidden_states.index_add_(0, self.sorted_token_indices, - valid_output) - else: - if self.with_quant: - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=self.topk_weights, - expanded_src_to_dst_row=self.expanded_row_idx, - export_for_source_row=self.topk_ids, - ) - if len(self.original_shape) == 3: - final_hidden_states = final_hidden_states.view( - self.original_shape) - else: - scales = torch.ones_like( - self.topk_weights - ) if self.apply_router_weight_on_input else self.topk_weights - # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=scales, - expanded_src_to_dst_row=self.expanded_row_idx, - export_for_source_row=self.topk_ids, - ) + # these values are no longer used, so they need to be set to None for memory release. 
+ self.expert_map = None + self.topk_weights = None + self.topk_ids = None + self.expanded_row_idx = None return final_hidden_states # mypy: disable-error-code="override" -class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): +class TokenDispatcherWithMoge(MoETokenDispatcher): def __init__(self, **kwargs): super().__init__(**kwargs) self.apply_router_weight_on_input = False - self.local_ep = 1 - self.local_num_experts = self.num_experts // self.local_ep - self.local_num_group = self.top_k // self.local_ep + self.local_num_experts = self.num_experts // self.ep_size + self.local_num_group = self.top_k // self.ep_size self.bsz = None def token_dispatch(self, @@ -501,23 +414,12 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): expert_map: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): - self.apply_router_weight_on_input = apply_router_weight_on_input - if self.apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - hidden_states = hidden_states * \ - topk_weights.to(hidden_states.dtype) - self.bsz, _ = hidden_states.shape flatten_topk_ids = topk_ids.view(-1) self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) @@ -551,7 +453,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): unsorted_hidden_states = hidden_states.index_select( 0, 
unsorted_topk_ids) final_hidden_states = unsorted_hidden_states.reshape( - self.bsz, self.top_k // self.local_ep, -1).sum(1) + self.bsz, self.top_k // self.ep_size, -1).sum(1) return final_hidden_states @@ -613,9 +515,9 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): expert_map: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): @@ -681,9 +583,14 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): output = self._combine_postprocess(permutated_local_input_tokens) + # these values are no longer used, so they need to be set to None for memory release. self.input_splits = None self.output_splits = None self.num_global_tokens_per_local_expert = None + self.topk_weights = None + self.reversed_local_input_permutation_mapping = None + self.reversed_global_input_permutation_mapping = None + self.global_input_tokens_local_experts_indices = None return output @@ -745,6 +652,10 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): self.global_input_tokens_local_experts_indices = torch.repeat_interleave( self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()) + else: + # TODO: This full synchronization can be a performance bottleneck. + # A more granular sync (e.g., blocking D2H copies) should be investigated. 
+ torch.npu.synchronize() return num_tokens_per_local_expert diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py new file mode 100644 index 0000000..438bff1 --- /dev/null +++ b/vllm_ascend/ops/register_custom_ops.py @@ -0,0 +1,201 @@ +import torch +import torch.nn.functional as F +import torch_npu +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.forward_context import get_forward_context +from vllm.utils import direct_register_custom_op + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_forward_context import MoECommType + + +def _maybe_chunk_residual_impl(x: torch.Tensor, + residual: torch.Tensor) -> torch.Tensor: + try: + forward_context = get_forward_context() + except AssertionError: + return residual + + if x.size(0) != residual.size(0): + sp_enabled = forward_context.sp_enabled + assert sp_enabled is True, ("Currently, this situation only occurs " + "when sp is enabled") + pad_size = forward_context.pad_size + if pad_size > 0: + residual = F.pad(residual, (0, 0, 0, pad_size)) + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + residual = torch.chunk(residual, tp_size, dim=0)[tp_rank] + + return residual + + +def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, + label: bool) -> torch.Tensor: + try: + forward_context = get_forward_context() + except AssertionError: + return x + + sp_enabled = forward_context.sp_enabled + if sp_enabled and label: + x = tensor_model_parallel_all_gather(x, 0) + pad_size = forward_context.pad_size + if pad_size > 0: + x = x[:-pad_size, :] + return x + + +def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor: + try: + forward_context = get_forward_context() + except AssertionError: + return tensor_model_parallel_all_reduce(x) + + sp_enabled 
= forward_context.sp_enabled + if sp_enabled: + pad_size = forward_context.pad_size + if pad_size > 0: + x = F.pad(x, (0, 0, 0, pad_size)) + return tensor_model_parallel_reduce_scatter(x, 0) + else: + return tensor_model_parallel_all_reduce(x) + + +def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor, + prefix: str) -> None: + try: + forward_context = get_forward_context() + except AssertionError: + return + + if not forward_context.prefetch_mlp_enabled: + return + model_instance = forward_context.model_instance + prefetch_stream = forward_context.prefetch_stream + layer_idx = int(prefix.split('.')[2]) + + # start point of gate_up_proj weight prefetch + if prefix.split('.')[-2] == "self_attn": + forward_context.prefetch_mlp_gate_up_proj = True + if forward_context.prefetch_mlp_gate_up_proj: + prefetch_stream.wait_stream(torch.npu.current_stream()) + + with torch.npu.stream(prefetch_stream): + mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE + torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \ + x_dependency, mlp_gate_up_prefetch_size) + return + + +def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor, + prefix: str) -> None: + return + + +def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None: + try: + forward_context = get_forward_context() + except AssertionError: + return + + if not forward_context.prefetch_mlp_enabled: + return + forward_context.prefetch_mlp_down_proj = True + model_instance = forward_context.model_instance + prefetch_stream = forward_context.prefetch_stream + layer_idx = forward_context.layer_idx + + # start point of down_proj weight prefetch + prefetch_stream.wait_stream(torch.npu.current_stream()) + + with torch.npu.stream(prefetch_stream): + mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE + torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \ + x_dependency, 
mlp_down_prefetch_size) + forward_context.layer_idx += 1 + return + + +def _maybe_prefetch_mlp_down_proj_impl_fake( + x_dependency: torch.Tensor) -> None: + return + + +def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None: + try: + forward_context = get_forward_context() + except AssertionError: + return + + if not forward_context.prefetch_mlp_enabled: + return + if forward_context.prefetch_mlp_gate_up_proj or \ + forward_context.prefetch_mlp_down_proj: + prefetch_stream = forward_context.prefetch_stream + # wait until prefetch done + torch.npu.current_stream().wait_stream(prefetch_stream) + forward_context.prefetch_mlp_gate_up_proj = False + forward_context.prefetch_mlp_down_proj = False + return + + +def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None: + return + + +def _maybe_all_reduce_tensor_model_parallel_impl( + final_hidden_states: torch.Tensor) -> torch.Tensor: + forward_context = get_forward_context() + moe_comm_type = forward_context.moe_comm_type + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) + + +direct_register_custom_op(op_name="maybe_chunk_residual", + op_func=_maybe_chunk_residual_impl, + fake_impl=lambda x, residual: residual, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad", + op_func=_maybe_all_gather_and_maybe_unpad_impl, + fake_impl=lambda x, label: x, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_pad_and_reduce", + op_func=_maybe_pad_and_reduce_impl, + fake_impl=lambda x: x, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj", + op_func=_maybe_prefetch_mlp_gate_up_proj_impl, + fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + 
+direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj", + op_func=_maybe_prefetch_mlp_down_proj_impl, + fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_wait_prefetch_done", + op_func=_maybe_wait_prefetch_done_impl, + fake_impl=_maybe_wait_prefetch_done_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel", + op_func=_maybe_all_reduce_tensor_model_parallel_impl, + fake_impl=lambda x: x, + mutates_args=[], + dispatch_key="PrivateUse1") diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 89e2bc7..9ddf280 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -20,6 +20,7 @@ from typing import Optional, Tuple import torch import torch_npu +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) @@ -37,34 +38,39 @@ def _rope_forward_oot( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - is_neox_style_override: Optional[bool] = None, + is_neox_style: bool, + offsets: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: query_shape, key_shape = query.shape, key.shape if self.cos_sin_cache.device != query.device: self.cos_sin_cache = self.cos_sin_cache.to(query.device) if self.cos_sin_cache.dtype != query.dtype: self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) - neox_style = self.is_neox_style - if is_neox_style_override is not None: - neox_style = is_neox_style_override # adopt custom kernel path for rotary_embedding - if _custom_rotary_embedding_enabled(query, neox_style, + if _custom_rotary_embedding_enabled(query, is_neox_style, self.head_size) and not is_310p(): - query, key = torch.ops._C.rotary_embedding( + query, key 
= torch.ops._C_ascend.rotary_embedding( positions, query, key, self.head_size, self.cos_sin_cache, - neox_style, + is_neox_style, ) return query.view(query_shape), key.view(key_shape) if offsets is not None: raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: - if self.rotary_dim < self.head_size: + if self.cos is not None and \ + self.sin is not None: + # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. + # This method requires head_size and rotary_dim equal 128 and neox_style is True + query = query.contiguous().view(1, query.shape[0], -1, + self.head_size) + key = key.contiguous().view(1, key.shape[0], -1, self.head_size) + torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin) + elif self.rotary_dim < self.head_size: num_tokens = query.shape[0] query = query.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size) @@ -80,25 +86,26 @@ def _rope_forward_oot( k_rot, self.head_size, self.cos_sin_cache, - neox_style, + is_neox_style, ) q_rot = q_rot.view(num_tokens, -1, self.rotary_dim) k_rot = k_rot.view(num_tokens, -1, self.rotary_dim) q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) return q, k - # TODO: Remove the contiguous in the future. - query = query.contiguous().view(query.shape[0], -1) - key = key.contiguous().view(key.shape[0], -1) - torch_npu._npu_rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - neox_style, - ) - return query.view(query_shape), key.view(key_shape) + else: + # TODO: Remove the contiguous in the future. 
+ query = query.contiguous().view(query.shape[0], -1) + key = key.contiguous().view(key.shape[0], -1) + torch_npu._npu_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + is_neox_style, + ) + return query.view(query_shape), key.view(key_shape) class AscendRotaryEmbedding(RotaryEmbedding): @@ -112,6 +119,8 @@ class AscendRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, dtype: torch.dtype, ) -> None: + self.cos = None + self.sin = None super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) @@ -123,14 +132,25 @@ class AscendRotaryEmbedding(RotaryEmbedding): offsets: Optional[torch.Tensor] = None, is_neox_style_override: Optional[bool] = None, ): - return _rope_forward_oot( - self, - positions, - query, - key, - offsets, - is_neox_style_override, - ) + is_neox_style = self.is_neox_style + if is_neox_style_override is not None: + is_neox_style = is_neox_style_override + forward_context = get_forward_context() + is_first_layer = forward_context.is_first_layer + # Generate cos and sin outside layers to avoid repeated calculation. 
+ if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ + -1] == 128: + if is_first_layer: + cos_sin = self.cos_sin_cache.index_select(0, positions) + last_dim = cos_sin.size()[-1] + cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat( + 1, 1, 2).chunk(2, dim=-2) + # BSNH + self.cos = cos.view(1, -1, 1, last_dim).contiguous() + self.sin = sin.view(1, -1, 1, last_dim).contiguous() + forward_context.is_first_layer = False + return _rope_forward_oot(self, positions, query, key, is_neox_style, + offsets) class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): @@ -168,8 +188,10 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): super(DeepseekScalingRotaryEmbedding, self).__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - self.max_seq_len = max_position_embeddings - self._set_cos_sin_cache(seq_len=max_position_embeddings, + + # NOTE: For ascend friendly computing, reorder sin and cos cache + self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor) + self._set_cos_sin_cache(self.max_seq_len, device=NPUPlatform.device_type, dtype=dtype) @@ -275,8 +297,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): return q_embed, k_embed - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len + def _set_cos_sin_cache(self, max_seq_len, device, dtype): dim = self.rotary_dim freq_extra = 1.0 / (self.base**( @@ -297,9 +318,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(seq_len * self.scaling_factor, - device=device, - dtype=torch.float32) + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale @@ -317,16 +336,13 @@ class 
AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - max_seq_len: Optional[int] = None): - if max_seq_len is not None and max_seq_len > self.max_seq_len: - self._set_cos_sin_cache(max_seq_len, query.device, query.dtype) + offsets: Optional[torch.Tensor] = None): if len(key.shape) == 2: key = key[:, None, :] # Note: we implement the non neox_style method with shuffle the last dim and neox style # calculation method which is also more compute friendly to the ascend machine # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py - neox_style = True + is_neox_style = True if self.is_neox_style is False: b, h_q, d = query.shape query = query.view(b, h_q, d // 2, @@ -334,6 +350,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): b, h_k, d = key.shape key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) - q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets, - neox_style) + q_pe, k_pe = _rope_forward_oot(self, positions, query, key, + is_neox_style, offsets) return q_pe, k_pe diff --git a/vllm_ascend/ops/sigmoid_gating.py b/vllm_ascend/ops/sigmoid_gating.py new file mode 100644 index 0000000..c99799c --- /dev/null +++ b/vllm_ascend/ops/sigmoid_gating.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors + +import os +from typing import Optional + +import torch +from vllm.triton_utils import tl, tldevice, triton + +if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': + div = tldevice.fast_dividef + exp = tldevice.fast_expf + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + + @triton.jit + def div_normal(x, y): + return x / y + + div = div_normal + exp = tl.exp + log = tl.log + log2 = tl.log2 + + +@triton.heuristics({ + 'USE_INITIAL_STATE': + lambda args: args['h0'] is not None, + 'IS_VARLEN': + lambda args: args['cu_seqlens'] is not None, + "IS_CONTINUOUS_BATCHING": + lambda args: args['ssm_state_indices'] is not None, + "IS_SPEC_DECODING": + lambda args: args['num_accepted_tokens'] is not None, +}) +@triton.jit(do_not_specialize=['N', 'T']) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.constexpr, # num of sequences + T: tl.constexpr, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl. 
+ constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_init_state_token + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t + p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t + p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t + else: + p_beta = beta + bos * HV + i_hv + HV * i_t + p_g = g + bos * HV + i_hv + HV * i_t + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t + + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * 
b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + # b_h *= tl.exp(b_g) + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + 
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + # print("N: ", N) + # print("T: ", T) + # print("B: ", B) + # print("H: ", H) + # print("HV: ", HV) + # print("K: ", K) + # print("V: ", V) + # print("BK: ", BK) + # print("BV: ", BV) + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) 
+ + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. 
+ Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_recurrent_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing.") + if scale is None: + scale = k.shape[-1]**-0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + ) + return o, final_state \ No newline at end of file diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index 7ad35dc..fe7ee51 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -97,6 +97,7 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): if params_dtype is None: params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype # Divide the weight matrix along the vocaburaly dimension. self.num_added_embeddings = self.num_embeddings - self.org_vocab_size self.num_embeddings_per_partition = divide(self.num_embeddings_padded, @@ -252,3 +253,16 @@ class AscendLogitsProcessor(LogitsProcessor): logits = logits[..., :self.org_vocab_size] return logits + + def forward( + self, + lm_head: VocabParallelEmbedding, + hidden_states: torch.Tensor, + # keep this for version compatibility + sampling_metadata=None, # type: ignore + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + return LogitsProcessor.forward(self, + lm_head, + hidden_states, + embedding_bias=embedding_bias) diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 754a344..7d0a232 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -46,6 +46,27 @@ # Need a PR to vllm to support get port from environment. # Future Plan: # Remove those patch when vllm merged them +# 2. 
`torch.distributed.all_reduce`, `torch.distributed.broadcast` +# Why: +# tensor alignment for 310p +# How: +# rewrite all_reduce and broadcast in torch.distributed +# Related PR (if no, explain why): +# No, not ready yet. +# Future Plan: +# Find a better way to support tensor alignment for 310p without this patch. +# +# ** File: platform/patch_common/patch_multimodal_merge.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.model_executor.models.utils._merge_multimodal_embeddings` +# Why: +# '_merge_multimodal_embeddings' func of vllm is incompatible with Ascend. +# How: +# Replace with CPU operation that can be executed asynchronously. +# Related PR (if no, explain why): +# This is a bug by Ascend only. It can't be fixed in vLLM. +# Future Plan: +# Identify this pattern in torch-npu and remove this patch. # # * Worker Patch: # =============== @@ -86,19 +107,15 @@ # - https://github.com/vllm-project/vllm/pull/21591 # Future Plan: # Revert it when vLLM merge #21591 and release new version -# ** File: worker/patch_common/patch_linear.py ** +# ** File: worker/patch_common/patch_logits.py ** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm.model_executor.layers.linear.RowParallelLinear` +# 1. `vllm._custom_ops.apply_repetition_penalties` # Why: -# We need to fuse matmul and allreuce in `RowParallelLinear` -# to improve performance. +# apply_repetition_penalties in vLLM use tensor.is_cuda to check if tensor is on cuda. But the value is always True +# on ascend, thus we need to patch apply_repetition_penalties. # How: -# Create a new class `AscendRowParallelLinear` that inherits from `RowParallelLinear`. -# In this class, we override the `forward` method to use -# torch_npu.npu_mm_all_reduce_base to replace matmul and allreduce. +# Remove the related cuda check in apply_repetition_penalties. # Related PR (if no, explain why): -# - https://github.com/vllm-project/vllm-ascend/pull/1926 +# - this is a bug by Ascend only. 
It can't be fixed in vLLM. # Future Plan: -# Validate more models in all kinds of scenario, -# if performance is always improved, we can enable this patch by default and remove the env -# variable `VLLM_ASCEND_ENABLE_FUSE_MATMUL_ALLREDUCE` in the future. +# Fix this bug in torch-npu, bump torch-npu version and remove this patch. diff --git a/vllm_ascend/patch/platform/patch_common/__init__.py b/vllm_ascend/patch/platform/patch_common/__init__.py index f88f2a9..89c74e7 100644 --- a/vllm_ascend/patch/platform/patch_common/__init__.py +++ b/vllm_ascend/patch/platform/patch_common/__init__.py @@ -15,4 +15,10 @@ # limitations under the License. # +import vllm_ascend.patch.platform.patch_common.patch_config # noqa import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa +import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa +import vllm_ascend.patch.platform.patch_common.patch_multimodal_merge # noqa +import vllm_ascend.patch.platform.patch_common.patch_transformers_utils # noqa +import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa +import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa diff --git a/vllm_ascend/patch/platform/patch_common/patch_config.py b/vllm_ascend/patch/platform/patch_common/patch_config.py new file mode 100644 index 0000000..9b6f5c2 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_common/patch_config.py @@ -0,0 +1,313 @@ +import ast + +import vllm.envs as envs +from transformers import PretrainedConfig +from vllm.config import ModelConfig +from vllm.config.speculative import SpeculativeConfig +from vllm.logger import logger + + +# mypy: ignore-errors +@property +def is_deepseek_mla(self: ModelConfig): + if not hasattr(self.hf_text_config, "model_type"): + return False + elif self.hf_text_config.model_type in \ + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', + 'kimi_k2', 'longcat_flash', 'deepseek_v32'): + return self.hf_text_config.kv_lora_rank is not None + elif 
self.hf_text_config.model_type == 'eagle': + # if the model is an EAGLE module, check for the + # underlying architecture + return self.hf_text_config.model.model_type in \ + ('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \ + and self.hf_text_config.kv_lora_rank is not None + return False + + +@staticmethod +def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + if hf_config.model_type in ("deepseek_v3", "deepseek_v32"): + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["DeepSeekMTPModel"] + }) + + if hf_config.architectures[0] == "MiMoForCausalLM": + hf_config.model_type = "mimo_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["MiMoMTPModel"] + }) + + if hf_config.architectures[0] == "Glm4MoeForCausalLM": + hf_config.model_type = "glm4_moe_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["Glm4MoeMTPModel"] + }) + + if hf_config.model_type == "ernie4_5_moe": + hf_config.model_type = "ernie_mtp" + if hf_config.model_type == "ernie_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["ErnieMTPModel"] + }) + + if hf_config.model_type == "qwen3_next": + hf_config.model_type = "qwen3_next_mtp" + if hf_config.model_type == "qwen3_next_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["Qwen3NextMTP"] + }) + if hf_config.model_type == "longcat_flash": + hf_config.model_type = "longcat_flash_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) + hf_config.update({ + 
"n_predict": n_predict, + "architectures": ["LongCatFlashMTPModel"] + }) + + return hf_config + + +def __post_init__(self): + + # Note: "method" is a new parameter that helps to extend the + # configuration of non-model-based proposers, and the "model" parameter + # will be used to set the draft model, eagle head, or additional weight + # when needed. If users do not specify "method", the speculative method + # will be detected automatically if possible. If the speculative method + # can not be detected, it will be considered as the "draft_model" by + # default. + + if self.model is None and self.num_speculative_tokens is not None: + # TODO(Shangming): Refactor mtp configuration logic when supporting + if (self.target_model_config + and self.target_model_config.hf_text_config.model_type + in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe", + "qwen3_next")): + # use the draft model from the same model: + self.model = self.target_model_config.model + # Align the quantization of draft model for cases such as + # --quantization fp8 with a bf16 checkpoint. + if not self.quantization: + self.quantization = self.target_model_config.quantization + elif self.method in ("ngram", "[ngram]"): + self.model = "ngram" + else: + raise ValueError("num_speculative_tokens was provided but without " + "speculative model.") + + # Automatically configure the method for ngram when "model" is used + # instead of "method" + if self.method is None and (self.model is not None + and self.model in ("ngram", "[ngram]")): + self.method = "ngram" + + if self.method in ("ngram", "[ngram]"): + # Unified to "ngram" internally + self.method = "ngram" + # Set default values if not provided + if (self.prompt_lookup_min is None and self.prompt_lookup_max is None): + # TODO(woosuk): Tune these values. They are arbitrarily chosen. 
+ self.prompt_lookup_min = 5 + self.prompt_lookup_max = 5 + elif self.prompt_lookup_min is None: + assert self.prompt_lookup_max is not None + self.prompt_lookup_min = self.prompt_lookup_max + elif self.prompt_lookup_max is None: + assert self.prompt_lookup_min is not None + self.prompt_lookup_max = self.prompt_lookup_min + + # Validate values + if self.prompt_lookup_min < 1: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") + if self.prompt_lookup_max < 1: + raise ValueError( + f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") + if self.prompt_lookup_min > self.prompt_lookup_max: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must " + f"be <= prompt_lookup_max={self.prompt_lookup_max}") + + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. + self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + else: + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if self.model is not None: + # TODO: Move this import to the top once `ModelConfig` + # lives in `vllm.config.model`. + from vllm.config import ModelConfig + self.draft_model_config = ModelConfig( + model=self.model, + runner="draft", + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config.trust_remote_code, + allowed_local_media_path=self.target_model_config. + allowed_local_media_path, + allowed_media_domains=self.target_model_config. + allowed_media_domains, + dtype=self.target_model_config.dtype, + seed=self.target_model_config.seed, + revision=self.revision, + code_revision=self.code_revision, + tokenizer_revision=self.target_model_config.tokenizer_revision, + spec_target_max_model_len=self.target_model_config. 
+ max_model_len, + quantization=self.quantization, + enforce_eager=self.target_model_config.enforce_eager, + max_logprobs=self.target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, + ) + + # Automatically detect the method + if self.method in ('eagle', 'eagle3'): + pass + # examples: + # yuhuili/EAGLE-LLaMA3-Instruct-8B + # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B + # AngelSlim/Qwen3-8B_eagle3 + elif "eagle-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif "eagle3" in self.draft_model_config.model.lower(): + self.method = "eagle3" + elif self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif (self.draft_model_config.hf_config.model_type == + "mlp_speculator"): + self.method = "mlp_speculator" + elif (self.draft_model_config.hf_config.model_type + in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")): + self.method = "deepseek_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Deepseek MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"): + self.method = "ernie_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Ernie MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type == + "qwen3_next_mtp"): + self.method = "qwen3_next_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Qwen3Next MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type + in ("longcat_flash_mtp")): + self.method = "longcat_flash_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "LongCat MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." 
+ ) + else: + self.method = "draft_model" + raise NotImplementedError( + "Speculative decoding with draft model is not " + "supported yet. Please consider using other " + "speculative decoding methods such as ngram, medusa, " + "eagle, or deepseek_mtp.") + + # Replace hf_config for EAGLE draft_model + if self.method in ("eagle", "eagle3"): + if self.enable_chunked_prefill and not envs.VLLM_USE_V1: + raise ValueError( + "Chunked prefill and EAGLE are not compatible " + "when using V0.") + + from vllm.transformers_utils.configs import SpeculatorsConfig + from vllm.transformers_utils.configs.eagle import EAGLEConfig + + if isinstance(self.draft_model_config.hf_config, + (EAGLEConfig, SpeculatorsConfig)): + pass + else: + eagle_config = EAGLEConfig( + self.draft_model_config.hf_config, + method=self.method, + model_type="eagle") + self.draft_model_config.hf_config = eagle_config + + if (self.num_speculative_tokens is not None + and hasattr(self.draft_model_config.hf_config, + "num_lookahead_tokens")): + self.draft_model_config.hf_config.num_lookahead_tokens = \ + self.num_speculative_tokens + + n_predict = getattr(self.draft_model_config.hf_config, "n_predict", + None) + if n_predict is not None: + if self.num_speculative_tokens is None: + # Default to max value defined in draft model config. + self.num_speculative_tokens = n_predict + elif self.num_speculative_tokens > n_predict and \ + self.num_speculative_tokens % n_predict != 0: + # Ensure divisibility for MTP module reuse. + raise ValueError( + f"num_speculative_tokens:{self.num_speculative_tokens}" + f" must be divisible by {n_predict=}") + + if self.speculative_token_tree is None: + # Generate chain of tokens. + self.speculative_token_tree = str([ + (i + 1) * (0, ) for i in range(self.num_speculative_tokens) + ]) + else: + # Sort the token tree breadth-first. 
+ tree_choices = ast.literal_eval(self.speculative_token_tree) + self.speculative_token_tree = str( + sorted(tree_choices, key=lambda t: (len(t), t))) + + self.draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config + ) + + self.draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, + )) + + self.draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + self.target_parallel_config, + self.draft_tensor_parallel_size)) + + +ModelConfig.is_deepseek_mla = is_deepseek_mla +SpeculativeConfig.__post_init__ = __post_init__ +SpeculativeConfig.hf_config_override = hf_config_override diff --git a/vllm_ascend/patch/platform/patch_common/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_common/patch_mamba_config.py new file mode 100644 index 0000000..c90ec8e --- /dev/null +++ b/vllm_ascend/patch/platform/patch_common/patch_mamba_config.py @@ -0,0 +1,100 @@ +# mypy: ignore-errors +import vllm.model_executor.models.config +from vllm.logger import init_logger +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.config import MambaModelConfig +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + +from vllm_ascend.ascend_config import get_ascend_config + + +@classmethod +def verify_and_update_config(cls, vllm_config) -> None: + """ + Ensure that page size of attention layers is greater than or + equal to the mamba layers. If not, automatically set the attention + block size to ensure that it is. If the attention page size is + strictly greater than the mamba page size, we pad the mamba page size + to make them equal. 
+ + Args: + vllm_config: vLLM Config + """ + logger = init_logger(__name__) + # Enable FULL_AND_PIECEWISE by default + MambaModelConfig.verify_and_update_config(vllm_config) + ascend_config = get_ascend_config() + + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + + if cache_config.cache_dtype == "auto": + kv_cache_dtype = model_config.dtype + else: + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # get attention page size (for 1 token) + attn_page_size_1_token = FullAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + use_mla=model_config.use_mla or ascend_config.use_sfa).page_size_bytes + + model_cls, _ = ModelRegistry.resolve_model_cls( + model_config.architecture, + model_config=model_config, + ) + + # get mamba page size + mamba_page_size = MambaSpec( + shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), + dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), + block_size=model_config.max_model_len, + ).page_size_bytes + + block_alignment_bytes = 64 + + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + attn_block_size = block_alignment_bytes * cdiv( + mamba_page_size, block_alignment_bytes * attn_page_size_1_token) + + # override attention block size if either (a) the + # user has not set it or (b) the user has set it + # too small. 
+ if (cache_config.block_size is None + or cache_config.block_size < attn_block_size): + cache_config.block_size = attn_block_size + logger.info( + "Setting attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size) + + # compute new attention page size + attn_page_size = \ + cache_config.block_size * attn_page_size_1_token + + assert attn_page_size >= mamba_page_size + + if attn_page_size == mamba_page_size: + # don't need to pad mamba page size + return + + # pad mamba page size to exactly match attention + if (cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size): + cache_config.mamba_page_size_padded = (attn_page_size) + mamba_padding_pct = 100 * (attn_page_size - + mamba_page_size) / mamba_page_size + logger.info( + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", mamba_padding_pct) + + +vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config diff --git a/vllm_ascend/patch/platform/patch_common/patch_multimodal_merge.py b/vllm_ascend/patch/platform/patch_common/patch_multimodal_merge.py new file mode 100644 index 0000000..c8a1d5c --- /dev/null +++ b/vllm_ascend/patch/platform/patch_common/patch_multimodal_merge.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +import torch +import vllm +from vllm.model_executor.models.utils import (_embedding_count_expression, + _flatten_embeddings) +from vllm.multimodal import NestedTensors + + +def _merge_multimodal_embeddings( + inputs_embeds: torch.Tensor, + is_multimodal: torch.Tensor, + multimodal_embeddings: NestedTensors, +) -> torch.Tensor: + """ + Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the + positions in ``inputs_embeds`` corresponding to placeholder tokens in + ``input_ids``. + + Note: + This updates ``inputs_embeds`` in place. + """ + flattened = _flatten_embeddings(multimodal_embeddings) + try: + inputs_embeds[is_multimodal] = flattened + except RuntimeError as e: + num_expected_tokens = is_multimodal.sum().item() + assert isinstance(num_expected_tokens, int) + + if flattened.shape[0] != num_expected_tokens: + expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( + f"Attempted to assign {expr} = {flattened.shape[0]} " + f"multimodal tokens to {num_expected_tokens} placeholders" + ) from e + else: + raise ValueError("Error during masked scatter operation") from e + + return inputs_embeds + + +vllm.model_executor.models.utils._merge_multimodal_embeddings = _merge_multimodal_embeddings diff --git a/vllm_ascend/patch/platform/patch_common/patch_transformers_utils.py b/vllm_ascend/patch/platform/patch_common/patch_transformers_utils.py new file mode 100644 index 0000000..55db190 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_common/patch_transformers_utils.py @@ -0,0 +1,200 @@ +import vllm.transformers_utils.configs +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from vllm.transformers_utils import config + +logger = logging.get_logger(__name__) + + +class DeepseekV3Config(PretrainedConfig): + r""" + This is 
the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. 
+ topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + >>> # Initializing a Deepseek-V3 style configuration + >>> configuration = DeepseekV3Config() + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=129280, + hidden_size=7168, + intermediate_size=18432, + moe_intermediate_size=2048, + num_hidden_layers=61, + num_nextn_predict_layers=1, + num_attention_heads=128, + num_key_value_heads=128, + n_shared_experts=1, + n_routed_experts=256, + ep_size=1, + routed_scaling_factor=2.5, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method='noaux_tc', + n_group=8, + topk_group=4, + num_experts_per_tok=8, + moe_layer_freq=1, + first_k_dense_replace=3, + norm_topk_prob=True, + scoring_func='sigmoid', + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = 
topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +vllm.transformers_utils.configs.__all__.append("DeepseekV3Config") +vllm.transformers_utils.configs.DeepseekV3Config = DeepseekV3Config +config._CONFIG_REGISTRY["deepseek_v32"] = "DeepseekV3Config" diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index 8d206bf..3d233c4 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -15,8 +15,18 @@ # limitations under the License. 
# +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + import vllm_ascend.patch.worker.patch_common.patch_triton + +# isort: off +import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa +import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa +import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa -import vllm_ascend.patch.worker.patch_common.patch_linear # noqa import vllm_ascend.patch.worker.patch_common.patch_logits # noqa -import vllm_ascend.patch.worker.patch_common.patch_lora_embedding # noqa -import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa +import vllm_ascend.patch.worker.patch_common.patch_weight_loader # noqa + +# TODO: revert me when triton import is fixed +# import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py b/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py new file mode 100644 index 0000000..6f4ad36 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py @@ -0,0 +1,202 @@ +from typing import List, Optional + +import torch +import vllm +import vllm.envs as envs +from torch import nn +from vllm.attention import Attention, AttentionType, get_attn_backend +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.selector import backend_name_to_enum +from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.platforms import current_platform + +from 
vllm_ascend.utils import vllm_version_is + + +class AscendAttention(Attention, nn.Module, AttentionLayerBase): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + logits_soft_cap: Optional[float] = None, + per_layer_sliding_window: Optional[int] = None, + use_mla: bool = False, + use_sfa: bool = False, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + attn_backend: Optional[type[AttentionBackend]] = None, + **extra_impl_args, + ) -> None: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. 
+ """ + nn.Module.__init__(self) + AttentionLayerBase.__init__(self) + + if per_layer_sliding_window is not None: + # per-layer sliding window + sliding_window = per_layer_sliding_window + elif cache_config is not None: + # model-level sliding window + sliding_window = cache_config.sliding_window + else: + sliding_window = None + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + is_attention_free = cache_config.is_attention_free + calculate_kv_scales = cache_config.calculate_kv_scales + else: + kv_cache_dtype = "auto" + block_size = 16 + is_attention_free = False + calculate_kv_scales = False + if num_kv_heads is None: + num_kv_heads = num_heads + assert num_heads % num_kv_heads == 0, \ + f"num_heads ({num_heads}) is not " \ + f"divisible by num_kv_heads ({num_kv_heads})" + + # The default k/v_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized k/v_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self.calculate_kv_scales = calculate_kv_scales + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) + # FlashAttn doesn't support quantizing the kv-cache only + # but requires q to be quantized as well. + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) + + # We also keep q/k/v_scale on host (cpu) memory for attention + # backends that require the scales to be on host instead of on device. + # e.g. Flashinfer + self._q_scale_float = 1.0 + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + + # The output scale on host memory. This should be the input scale of + # the quant op after this attention layer. 
+ self._o_scale_float: Optional[float] = None + + self.use_mla = use_mla + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.sliding_window = sliding_window + self.has_sink = extra_impl_args.get("sinks") is not None + + quant_method = quant_config.get_quant_method( + self, prefix=prefix) if quant_config else None + if quant_method is not None and not isinstance( + quant_method, UnquantizedLinearMethod): + assert isinstance(quant_method, BaseKVCacheMethod) + # TODO (mgoin): kv cache dtype should be specified in the FP8 + # checkpoint config and become the "auto" behavior + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # If quantization is enabled, we make "k_scale" and "v_scale" + # parameters so that it can be loaded from the model checkpoint. + # The k/v_scale will then be converted back to native float32 + # values after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + + # During model initialization, the default dtype is set as the model + # weight and activation dtype. 
+ dtype = torch.get_default_dtype() + if attn_backend is None: + if vllm_version_is("0.10.2"): + self.attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + use_mla=use_mla, + use_sfa=use_sfa, + has_sink=self.has_sink) + else: + self.attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla=use_mla, + use_sfa=use_sfa, + has_sink=self.has_sink) + else: + self.attn_backend = attn_backend + + impl_cls = self.attn_backend.get_impl_cls() + self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **extra_impl_args) + self.backend = backend_name_to_enum(self.attn_backend.get_name()) + self.dtype = dtype + + # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # torch.compile works by registering the attention as one giant + # opaque custom op. For other platforms, we directly call them + # and let torch.compile handle them. 
+ self.use_direct_call = not current_platform.opaque_attention_op() + + self.use_output = self.attn_backend.accept_output_buffer + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + self.attn_type = attn_type + + if kv_sharing_target_layer_name is not None: + validate_kv_sharing_target( + prefix, + kv_sharing_target_layer_name, + compilation_config.static_forward_context, + ) + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + # use a placeholder kv cache tensor during init, which will be replaced + # by bind_kv_cache + # this variable will not be accessed if use_direct_call is True + self.kv_cache = [ + torch.tensor([]) for _ in range(get_current_vllm_config( + ).parallel_config.pipeline_parallel_size) + ] + + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + self.query_quant = None + + +vllm.attention.Attention = AscendAttention diff --git a/vllm_ascend/patch/worker/patch_common/patch_attention_selector.py b/vllm_ascend/patch/worker/patch_common/patch_attention_selector.py new file mode 100644 index 0000000..793fef1 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_attention_selector.py @@ -0,0 +1,181 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# mypy: ignore-errors +from functools import cache +from typing import Optional + +import torch +import vllm +import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.selector import (backend_name_to_enum, + get_global_forced_attn_backend) +from vllm.platforms import _Backend, current_platform +from vllm.utils import resolve_obj_by_qualname + +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.10.2"): + + def get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool = False, + use_mla: bool = False, + use_sfa: bool = False, + has_sink: bool = False, + ) -> type[AttentionBackend]: + """Selects which attention backend to use and lazily imports it.""" + # Accessing envs.* behind an @lru_cache decorator can cause the wrong + # value to be returned from the cache if the value changes between calls. + # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the + # private function. 
+ return _cached_get_attn_backend( + head_size=head_size, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + block_size=block_size, + is_attention_free=is_attention_free, + use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, + use_sfa=use_sfa, + has_sink=has_sink, + ) + + @cache + def _cached_get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + use_v1: bool = False, + use_mla: bool = False, + use_sfa: bool = False, + has_sink: bool = False, + ) -> type[AttentionBackend]: + # If there are no attention layers (e.g. we are running Mamba), + # use the placeholder NO_ATTENTION + if is_attention_free: + from vllm.attention.backends.placeholder_attn import \ + PlaceholderAttentionBackend + return PlaceholderAttentionBackend + + # Check whether a particular choice of backend was + # previously forced. + # + # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. + selected_backend = None + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + raise ValueError( + f"Invalid attention backend: '{backend_by_env_var}'. 
" + f"Valid backends are: {list(_Backend.__members__.keys())}" + ) + + # get device-specific attn_backend + attention_cls = current_platform.get_attn_backend_cls( + selected_backend, head_size, dtype, kv_cache_dtype, block_size, + use_v1, use_mla, use_sfa, has_sink) + if not attention_cls: + raise ValueError( + f"Invalid attention backend for {current_platform.device_name}" + ) + return resolve_obj_by_qualname(attention_cls) +else: + + def get_attn_backend( # type: ignore[misc] + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_mla: bool = False, + use_sfa: bool = False, + has_sink: bool = False, + ) -> type[AttentionBackend]: + """Selects which attention backend to use and lazily imports it.""" + # Accessing envs.* behind an @lru_cache decorator can cause the wrong + # value to be returned from the cache if the value changes between calls. + # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the + # private function. + return _cached_get_attn_backend( + head_size=head_size, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + block_size=block_size, + use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, + use_sfa=use_sfa, + has_sink=has_sink, + ) + + @cache + def _cached_get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool = False, + use_mla: bool = False, + use_sfa: bool = False, + has_sink: bool = False, + ) -> type[AttentionBackend]: + # Check whether a particular choice of backend was + # previously forced. + # + # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. 
+ selected_backend = None + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + raise ValueError( + f"Invalid attention backend: '{backend_by_env_var}'. " + f"Valid backends are: {list(_Backend.__members__.keys())}" + ) + + # get device-specific attn_backend + attention_cls = current_platform.get_attn_backend_cls( + selected_backend, head_size, dtype, kv_cache_dtype, block_size, + use_v1, use_mla, use_sfa, has_sink) + if not attention_cls: + raise ValueError( + f"Invalid attention backend for {current_platform.device_name}" + ) + return resolve_obj_by_qualname(attention_cls) + + +vllm.attention.get_attn_backend = get_attn_backend +vllm.attention.selector._cached_get_attn_backend = _cached_get_attn_backend diff --git a/vllm_ascend/patch/worker/patch_common/patch_attentionspec.py b/vllm_ascend/patch/worker/patch_common/patch_attentionspec.py new file mode 100644 index 0000000..e1a5ac5 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_attentionspec.py @@ -0,0 +1,110 @@ +from dataclasses import dataclass, fields +from typing import Optional + +import torch +import vllm +from typing_extensions import Self +from vllm.config import VllmConfig +from vllm.utils import cdiv, get_dtype_size +from vllm.v1.core.single_type_kv_cache_manager import (FullAttentionManager, + spec_manager_map) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec + + +@dataclass(frozen=True) +class AttentionSpec(KVCacheSpec): + num_kv_heads: int + head_size: int + dtype: torch.dtype + use_mla: bool + use_sfa: bool + + @property + def page_size_bytes(self) -> int: + # For MLA we only 
store a single latent vector + coef = 1 if self.use_mla else 2 + sfa_bytes = 128 * self.block_size * get_dtype_size( + self.dtype) if self.use_sfa else 0 + + return coef * self.block_size * self.num_kv_heads * self.head_size \ + * get_dtype_size(self.dtype) + sfa_bytes + + +vllm.v1.kv_cache_interface.AttentionSpec = AttentionSpec + + +@dataclass(frozen=True) +class AscendFullAttentionSpec(FullAttentionSpec, AttentionSpec): + sliding_window: Optional[int] = None + attention_chunk_size: Optional[int] = None + """ + When hybrid allocator is disabled and the model contains both full + attention layers and sliding window attention layers, sliding + window attention are regarded as full attention in KV cache manager + (blocks are allocated for all tokens), while computed as sliding window + attention in model runner. + In this case, we use FullAttentionSpec and record the sliding window size. + Default to None for not using sliding window attention. + """ + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + dcp_world_size = \ + vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): each dcp rank only need save + # (max_model_len//dcp_world_size) tokens locally. + if dcp_world_size > 1: + max_model_len = cdiv(max_model_len, dcp_world_size) + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + + @classmethod + def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]: + if len(window_sizes) == 0: + return None + elif len(window_sizes) == 1: + return window_sizes.pop() + else: + raise ValueError( + "All attention layers in the same KV cache group must have the " + "same window size.") + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ + Merge a list of FullAttentionSpec objects into a single + FullAttentionSpec object. 
+ """ + assert all(isinstance(spec, FullAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be " + "FullAttentionSpec.") + + sliding_window = set(spec.sliding_window for spec in specs + if spec.sliding_window is not None) + attention_chunk_size = set(spec.attention_chunk_size for spec in specs + if spec.attention_chunk_size is not None) + merged_spec = cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + dtype=specs[0].dtype, + use_mla=specs[0].use_mla, + use_sfa=specs[0].use_sfa, + sliding_window=cls.merge_window_sizes(sliding_window), + attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), + ) + for spec in specs: + for f in fields(AttentionSpec): + assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( + "All attention layers in the same KV cache group must have " + "the same attention spec.") + assert ( + (merged_spec.sliding_window is not None) + + (merged_spec.attention_chunk_size is not None) <= 1 + ), ("Model with both sliding window layers and chunked local attention " + "layers is not supported.") + return merged_spec + + +spec_manager_map.update({AscendFullAttentionSpec: FullAttentionManager}) + +vllm.v1.kv_cache_interface.FullAttentionSpec = AscendFullAttentionSpec diff --git a/vllm_ascend/patch/worker/patch_common/patch_linear.py b/vllm_ascend/patch/worker/patch_common/patch_linear.py deleted file mode 100644 index 5690ba8..0000000 --- a/vllm_ascend/patch/worker/patch_common/patch_linear.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -This file is a part of the vllm-ascend project. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -from typing import Optional, Union - -import torch -import torch_npu -import vllm -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter -from vllm.distributed import (get_tensor_model_parallel_rank, - split_tensor_along_last_dim) -from vllm.distributed.parallel_state import get_tp_group -from vllm.logger import logger -from vllm.model_executor.layers.linear import RowParallelLinear - -import vllm_ascend.envs as envs_ascend - -_HCOMM_INFO = None - - -class AscendRowParallelLinear(RowParallelLinear): - """ - AscendRowParallelLinear is a custom implementation of RowParallelLinear - that overrides the forward method to handle Ascend-specific operations. - """ - - def __init__(self, *args, **kwargs): - """Initialize the AscendRowParallelLinear layer. - - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - """ - tp_group = get_tp_group().device_group - hcomm_info = self.get_hcomm_info(tp_group) - self.hcomm_info = hcomm_info - super().__init__(*args, **kwargs) - self.weight_t = self.weight.t() - - @staticmethod - def get_hcomm_info(group: ProcessGroup) -> str: - """Get the HCCL communication information for the given group. - - Args: - group (ProcessGroup): The process group for which to get the HCCL communication info. - - Returns: - str: The HCCL communication name for the given group. 
- """ - global _HCOMM_INFO - if _HCOMM_INFO is not None: - return _HCOMM_INFO - - rank = torch.distributed.get_rank(group) - if torch.__version__ > "2.0": - global_rank = torch.distributed.get_global_rank(group, rank) - _HCOMM_INFO = group._get_backend( - torch.device("npu")).get_hccl_comm_name(global_rank) - - else: - _HCOMM_INFO = group.get_hccl_comm_name(rank) - return _HCOMM_INFO - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - """Forward pass for the AscendRowParallelLinear layer. - - Args: - input_ (torch.Tensor): the input tensor to the layer. - - Returns: - Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - The output tensor after applying the linear transformation, - and optionally the bias if `return_bias` is True. - """ - input_parallel = self.calc_input(input_) - - # Matrix multiply. - assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) - output = self.calc_output(input_parallel) - - output_bias = self.bias if self.skip_bias_add else None - - if not self.return_bias: - return output - return output, output_bias - - def calc_input(self, input_: torch.Tensor) -> torch.Tensor: - """Calculate the input tensor for parallel processing. - - Args: - input_ (torch.Tensor): the input tensor to be processed. - - Returns: - torch.Tensor: The input tensor split along the last dimension - for tensor model parallelism, or the original input if not parallel. - """ - if self.input_is_parallel: - return input_ - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - return splitted_input[tp_rank].contiguous() - - def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor: - """Calculate the output tensor of forward by considering - fusing communication and computation. 
- - Args: - input_parallel (_type_): the input tensor to be processed in parallel. - - Returns: - torch.Tensor: the output tensor after applying the linear transformation - and optionally handle communication between tensor model parallel ranks. - """ - bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - if self.reduce_results and self.tp_size > 1: - output = torch_npu.npu_mm_all_reduce_base(input_parallel, - self.weight_t, - self.hcomm_info, - bias=bias_) - else: - output = self.quant_method.apply(self, input_parallel, bias=bias_) - return output - - -if envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE: - logger.info("AscendRowParallelLinear: Matmul all-reduce is enabled. ") - vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear diff --git a/vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py b/vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py deleted file mode 100644 index 02d5804..0000000 --- a/vllm_ascend/patch/worker/patch_common/patch_lora_embedding.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional - -import vllm -from torch import nn -from transformers import PretrainedConfig -from vllm.config import LoRAConfig -from vllm.lora.layers import VocabParallelEmbeddingWithLoRA -from vllm.lora.utils import _all_lora_classes - -from vllm_ascend.ops.vocab_parallel_embedding import \ - AscendVocabParallelEmbedding - - -class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA): - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is AscendVocabParallelEmbedding - - -# Patch for lora register_model issue after overriding VocabParallelEmbedding class (#2515) -_all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA) -vllm.lora.utils._all_lora_classes = _all_lora_classes diff --git 
a/vllm_ascend/patch/worker/patch_common/patch_triton.py b/vllm_ascend/patch/worker/patch_common/patch_triton.py new file mode 100644 index 0000000..8904054 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_triton.py @@ -0,0 +1,16 @@ +import vllm.model_executor.layers.fla.ops.chunk +import vllm.model_executor.layers.fla.ops.fused_recurrent +import vllm.model_executor.layers.fla.ops.layernorm_guard +import vllm.model_executor.layers.mamba.ops.causal_conv1d + +from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn, + causal_conv1d_update_npu) +from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule +from vllm_ascend.ops.sigmoid_gating import \ + fused_recurrent_gated_delta_rule_fwd_kernel + +vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu +vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn +vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel +vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn +vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule \ No newline at end of file diff --git a/vllm_ascend/patch/worker/patch_common/patch_weight_loader.py b/vllm_ascend/patch/worker/patch_common/patch_weight_loader.py new file mode 100644 index 0000000..10705d3 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_weight_loader.py @@ -0,0 +1,44 @@ +import torch +from torch.nn.parameter import Parameter +from vllm.logger import init_logger +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import GiB_bytes + +from vllm_ascend.utils import vllm_version_is + +logger = init_logger(__name__) + + +def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + 
**extra_weight_attrs): + # This method creates unquantized linear weights. + # The weights are not quantized, and they are not sharded. + # The amount of memory allocated for the weights is + # sum(output_partition_sizes) * input_size_per_partition. + try: + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + except torch.cuda.OutOfMemoryError as e: + logger.error("Failed to create unquantized linear weights: %s", e) + if torch.cuda.is_available(): + logger.debug("CUDA device: %s", torch.cuda.current_device()) + logger.debug("Allocated: %.2f GiB", + torch.cuda.memory_allocated() / GiB_bytes) + logger.debug("Reserved: %.2f GiB", + torch.cuda.memory_reserved() / GiB_bytes) + raise RuntimeError( + "Failed to create unquantized linear weights. " + "This may be caused by insufficient memory to allocate " + "the weight.") from e + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + +if not vllm_version_is("0.10.2"): + from vllm.model_executor.layers.linear import UnquantizedLinearMethod + UnquantizedLinearMethod.create_weights = create_weights diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 57ace2b..f1581df 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -16,6 +16,7 @@ # import gc +import os from datetime import timedelta from typing import TYPE_CHECKING, Optional, Tuple @@ -31,7 +32,7 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config, from vllm_ascend.torchair.utils import (check_torchair_cache_exist, delete_torchair_cache_file) from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p, - update_aclgraph_sizes) + update_aclgraph_sizes, vllm_version_is) if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -128,11 +129,43 @@ class NPUPlatform(Platform): model_config = 
vllm_config.model_config parallel_config = vllm_config.parallel_config cache_config = vllm_config.cache_config + scheduler_config = vllm_config.scheduler_config + ascend_scheduler_config = ascend_config.ascend_scheduler_config + if vllm_version_is("0.10.2"): + structured_outputs_config = vllm_config.decoding_config + else: + structured_outputs_config = vllm_config.structured_outputs_config + + if (model_config is not None and not model_config.use_mla + and not scheduler_config.async_scheduling): + logger.info( + "Non-MLA LLMs forcibly disable the chunked prefill feature," + "as the performance of operators supporting this feature " + "functionality is currently suboptimal.") + if not model_config.is_multimodal_model and \ + structured_outputs_config.backend == "auto" and \ + not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \ + not scheduler_config.send_delta_data and \ + scheduler_config.policy == "fcfs": + ascend_scheduler_config.enabled = True + chunked_prefill_enabled_in_ascend_scheduler = getattr( + ascend_scheduler_config, "enable_chunked_prefill", False) + if chunked_prefill_enabled_in_ascend_scheduler: + logger.warning( + "Chunked prefill feature is enabled in ascend_scheduler," + "but note that the operator supporting this feature " + "would lead to performance degradation.") + # In this situation, max_num_batched_tokens would have been rewritten. + # So we must make sure max_num_batched_tokens is not smaller than max_model_len. + if (scheduler_config.max_num_batched_tokens + < scheduler_config.max_model_len + and not chunked_prefill_enabled_in_ascend_scheduler): + scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len + kv_cache_dtype = vllm_config.additional_config.get( "kv_cache_dtype", None) if kv_cache_dtype is not None: vllm_config.cache_config.cache_dtype = kv_cache_dtype - if model_config is None: logger.warning("Model config is missing. 
This may indicate " "that we are running a test case") @@ -148,23 +181,13 @@ class NPUPlatform(Platform): compilation_config.cudagraph_num_of_warmups = 1 - # TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode` - # if cudagraph_mode is not explicitly set by users, set default value - if compilation_config.level == CompilationLevel.PIECEWISE: - compilation_config.cudagraph_mode = \ - CUDAGraphMode.PIECEWISE - elif compilation_config.level not in [ + if compilation_config.level not in [ CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE ]: logger.warning( "NPU does not support %s compilation level. Setting CUDAGraphMode to NONE", compilation_config.level) compilation_config.cudagraph_mode = CUDAGraphMode.NONE - else: - logger.warning( - "compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE" - ) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE # set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is. if ascend_config.torchair_graph_config.enabled: @@ -185,18 +208,22 @@ class NPUPlatform(Platform): "and use_cached_kv_cache_bytes in torchair_graph_config.") delete_torchair_cache_file() - if parallel_config.distributed_executor_backend == "ray": - logger.warning( - "Ray distributed executor backend is not compatible with ACL Graph mode " - "right now. Setting CUDAGraphMode to NONE") - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - # set cudaprah sizes before extending `compilation_config.splitting_ops` vllm_config._set_cudagraph_sizes() + # TODO: Full graph is fully supported later, and the default value will be set to full graph. 
+ if not vllm_version_is("0.10.2"): + if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: + compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: compilation_config.level = CompilationLevel.NO_COMPILATION - elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: + # TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition + # after MLA being supported + elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or ( + compilation_config.cudagraph_mode + == CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None + and model_config.use_mla): logger.info( "PIECEWISE compilation enabled on NPU. use_inductor not supported - " "using only ACL Graph mode") @@ -204,9 +231,28 @@ class NPUPlatform(Platform): "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE" compilation_config.set_splitting_ops_for_v1() compilation_config.use_inductor = False - compilation_config.splitting_ops.extend( - ["vllm.unified_ascend_attention_with_output"]) + compilation_config.splitting_ops.extend([ + "vllm.unified_ascend_attention_with_output", "vllm.mla_forward" + ]) update_aclgraph_sizes(vllm_config) + elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + logger.info( + "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " + "using only ACL Graph mode") + compilation_config.use_inductor = False + warning_message = """\033[91m + ********************************************************************************** + * WARNING: You have enabled the *full graph* feature. + * This is an early experimental stage and may involve various unknown issues. + * A known problem is that capturing too many batch sizes can lead to OOM + * (Out of Memory) errors or inference hangs. 
If you encounter such issues, + * consider reducing `gpu_memory_utilization` or manually specifying a smaller + * batch size for graph capture. + * For more details, please refer to: + * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs + **********************************************************************************\033[0m + """ + logger.warning(warning_message) else: logger.info( "%s cudagraph_mode is not support on NPU. falling back to NONE", @@ -215,7 +261,9 @@ class NPUPlatform(Platform): compilation_config.level = CompilationLevel.NO_COMPILATION if parallel_config and parallel_config.worker_cls == "auto": - if ascend_config.torchair_graph_config.enabled: + # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm. + os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv" + if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp: parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker" else: parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker" @@ -223,6 +271,7 @@ class NPUPlatform(Platform): if cache_config: if cache_config.block_size is None: cache_config.block_size = 128 + if cache_config.enable_prefix_caching and cache_config.block_size != 128: logger.warning( "If prefix caching is enabled, block size must be set to 128." @@ -242,12 +291,6 @@ class NPUPlatform(Platform): ascend_config.ascend_scheduler_config) vllm_config.scheduler_config = ascend_scheduler_config - if compilation_config.pass_config.enable_sequence_parallelism: - if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe": - raise NotImplementedError( - "For better performance in Qwen3 MoE, SP only works exclusively with MC2, AllToAll, and AllToAllV." 
- ) - @classmethod def get_attn_backend_cls(cls, selected_backend, @@ -257,27 +300,40 @@ class NPUPlatform(Platform): block_size, use_v1, use_mla, + use_sfa, has_sink=False): if not use_v1: raise ValueError("vLLM Ascend does not support V0 engine.") - use_torchair = get_ascend_config().torchair_graph_config.enabled + ascend_config = get_ascend_config() + + if use_mla and ascend_config.enable_shared_expert_dp: + if use_mla and not use_sfa: + return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend" + if use_mla and use_sfa: + return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend" + + use_torchair = ascend_config.torchair_graph_config.enabled # choose attention backend based on use_mla and use_torchair backend_map = { - (True, True): + (True, False, True): "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend", - (True, False): + (True, False, False): "vllm_ascend.attention.mla_v1.AscendMLABackend", - (False, True): + (False, False, True): "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend", - (False, False): - "vllm_ascend.attention.attention_v1.AscendAttentionBackend" + (False, False, False): + "vllm_ascend.attention.attention_v1.AscendAttentionBackend", + (True, True, False): + "vllm_ascend.attention.sfa_v1.AscendSFABackend", + (True, True, True): + "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend", } - return backend_map[(use_mla, use_torchair)] + return backend_map[(use_mla, use_sfa, use_torchair)] @classmethod def get_punica_wrapper(cls) -> str: - return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU" + return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU" @classmethod def get_current_memory_usage(cls, @@ -343,3 +399,11 @@ class NPUPlatform(Platform): pg._register_backend(device, backend_type, backend_class) return pg + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True + + @classmethod + def support_static_graph_mode(cls) -> bool: + return True diff --git 
a/vllm_ascend/quantization/func_wrapper.py b/vllm_ascend/quantization/func_wrapper.py deleted file mode 100644 index 8357695..0000000 --- a/vllm_ascend/quantization/func_wrapper.py +++ /dev/null @@ -1,184 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Optional, Tuple, Union - -import torch -import torch_npu -from vllm.logger import logger -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import UnquantizedLinearMethod -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig) - - -# func refers to vocabParallelEmbedding.__init__ -def wrapper_vocab_parallel_embedding_init(func): - - def init( - self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - func( - self, - num_embeddings, - embedding_dim, - params_dtype, - org_num_embeddings, - padding_size, - quant_config, - prefix, - ) - # TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class. 
- if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - - return init - - -# func refers to RMSNorm.__init__ -def wrapper_rmsnorm_init(func): - - def init(self, hidden_size: int, **extra_args) -> None: - func(self, hidden_size, **extra_args) - self.ignore_anti = True - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), - requires_grad=False) - - return init - - -# func refers to RMSNorm.forward_oot -def wrapper_rmsnorm_forward_oot(func): - - def _rmsnorm_forward_oot( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if not self.ignore_anti: - if residual is not None: - residual += x - out = torch_npu._npu_quant_rms_norm( - residual, - self.weight, - self.bias, - self.input_scale, - self.input_offset, - self.variance_epsilon, - ) - return out, residual - out = torch_npu._npu_quant_rms_norm( - x, - self.weight, - self.bias, - self.input_scale, - self.input_offset, - self.variance_epsilon, - ) - return out - - if residual is not None: - x, residual = func(self, x, residual) - return x.add_(self.bias), residual - - return func(self, x).add_(self.bias) - - return _rmsnorm_forward_oot - - -MODEL_LAYER_MAPPING = { - "LlamaModel": { - "attn": { - "layer_attr": "self_attn", - "proj_attr": "qkv_proj", - "norm_attr": "input_layernorm", - "unquantized_type": UnquantizedLinearMethod, - }, - "mlp": { - "layer_attr": "mlp", - "proj_attr": "gate_up_proj", - "norm_attr": "post_attention_layernorm", - "unquantized_type": UnquantizedLinearMethod, - }, - }, -} - - -def wrapper_load_model(func): - - def postprocess_loading(self) -> None: - func(self) - - def process_layer(layer, idx, mapping): - - def process_module(module_cfg, layer_obj): - if module_cfg is None: - return - - module_obj = getattr(layer_obj, module_cfg["layer_attr"], None) - if module_obj is None: - return - - proj_attr = module_cfg["proj_attr"] - if callable(proj_attr): - proj = 
proj_attr(module_obj, idx) - else: - proj = getattr(module_obj, proj_attr, None) - - norm = getattr(layer_obj, module_cfg["norm_attr"], None) - - if proj is None or norm is None: - return - - norm.ignore_anti = isinstance(proj.quant_method, - module_cfg["unquantized_type"]) - if not norm.ignore_anti: - for param_name in ["input_scale", "input_offset"]: - if hasattr(proj, param_name): - param = getattr(proj, param_name) - norm.register_parameter( - param_name, - torch.nn.Parameter(param.clone(), - requires_grad=False)) - - process_module(mapping.get("attn"), layer) - process_module(mapping.get("mlp"), layer) - - model_type = self.model.model.__class__.__name__ - mapping = MODEL_LAYER_MAPPING.get(model_type) - - if not mapping: - logger.info( - f"Warning: Model type '{model_type}' not found in MODEL_LAYER_MAPPING. Skipping layer mapping." - ) - return - - for idx, layer in enumerate(self.model.model.layers): - process_layer(layer, idx, mapping) - - if isinstance(self.model.model.norm, RMSNorm): - self.model.model.norm.ignore_anti = True - - return postprocess_loading diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index d449c8d..130251c 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -19,6 +19,7 @@ from types import MappingProxyType from typing import Any, Callable, Dict, List, Mapping, Optional import torch +from vllm.config import get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) @@ -32,13 +33,15 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.vocab_parallel_embedding import ( UnquantizedEmbeddingMethod, VocabParallelEmbedding) -from vllm.model_executor.parameter import 
PerTensorScaleParameter from vllm.model_executor.utils import set_weight_attrs +from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, + get_otp_group) from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod -from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD +from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable, + oproj_tp_enable) -from .quantizer import AscendQuantizer +from .utils import get_quant_method @register_quantization_config(ASCEND_QUANTIZATION_METHOD) @@ -50,6 +53,7 @@ class AscendQuantConfig(QuantizationConfig): """ def __init__(self, quant_config: Dict[str, Any]): + super().__init__() self.quant_description = quant_config def __repr__(self) -> str: @@ -85,7 +89,14 @@ class AscendQuantConfig(QuantizationConfig): def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: + vllm_config = get_current_vllm_config() + model_type = vllm_config.model_config.hf_config.model_type + if model_type in packed_modules_model_mapping: + self.packed_modules_mapping = packed_modules_model_mapping[ + model_type] from vllm.attention.layer import Attention + if prefix.startswith("language_model"): + prefix = prefix.split('.', 1)[-1] if isinstance(layer, LinearBase): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): @@ -147,21 +158,86 @@ class AscendQuantConfig(QuantizationConfig): return [] +packed_modules_model_mapping = { + "qwen3_moe": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + }, + "deepseek_v2": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, + "deepseek_v3": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, + # NOTE 1.The quantized 
MTP layer of deepseek on the NPU is not quantized; + # NOTE 2.The description file generated by the current msmodelslim tool does not have + # MTP layer info. Please manually add it and set the value to FLOAT. + "deepseek_mtp": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, + "qwen3_next": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["gate_proj", "up_proj"], + "in_proj": ["in_proj_qkvz", "in_proj_ba"], + }, + "qwen2_5_vl": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + }, + "glm4_moe": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, +} + + class AscendLinearMethod(LinearMethodBase): """Linear method for Ascend quantization. - This class calls AscendQuantizer to search a specific quantization - implementations supported on ascend hardware for linear methods. - Args: quant_config: The Ascend quantization config. 
""" def __init__(self, quant_config: AscendQuantConfig, prefix: str, packed_modules_mapping: Dict[str, Any]) -> None: - self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description, prefix, packed_modules_mapping) - self.quant_method = self.quantizer.build_linear_method() + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "linear", + packed_modules_mapping) def create_weights( self, @@ -174,7 +250,6 @@ class AscendLinearMethod(LinearMethodBase): **extra_weight_attrs, ) -> None: output_size_per_partition = sum(output_partition_sizes) - weight_loader = extra_weight_attrs.get("weight_loader") weight_dict = self.quant_method.get_weight(input_size_per_partition, output_size_per_partition, @@ -187,8 +262,7 @@ class AscendLinearMethod(LinearMethodBase): pertensor_dict = self.quant_method.get_pertensor_param(params_dtype) for pertensor_name, pertensor_param in pertensor_dict.items(): - param = PerTensorScaleParameter(data=pertensor_param, - weight_loader=weight_loader) + param = torch.nn.Parameter(pertensor_param, requires_grad=False) # disable warning param.ignore_warning = True layer.register_parameter(pertensor_name, param) @@ -223,25 +297,27 @@ class AscendLinearMethod(LinearMethodBase): bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: if isinstance(layer, RowParallelLinear): - tp_rank = get_tensor_model_parallel_rank() - return self.quant_method.apply(layer, x, bias, tp_rank) - return self.quant_method.apply(layer, x, bias) + if layer.prefix.find("o_proj") != -1 and oproj_tp_enable(): + tp_rank = get_otp_group().rank_in_group + elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable(): + tp_rank = get_mlp_tp_group().rank_in_group + else: + tp_rank = get_tensor_model_parallel_rank() + else: + tp_rank = 0 + return self.quant_method.apply(layer, x, bias, tp_rank) class AscendKVCacheMethod(BaseKVCacheMethod): """KVCache method for Ascend quantization. 
- This class calls AscendQuantizer to search a specific quantization - implementations supported on ascend hardware for kvcache methods. - Args: quant_config: The Ascend quantization config. """ def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None: - self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description, prefix) - self.quant_method = self.quantizer.build_attention_method() + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "attention") def create_weights(self, layer: torch.nn.Module) -> None: # Different from linear method, there are no weight processing/slicing @@ -263,18 +339,15 @@ class AscendKVCacheMethod(BaseKVCacheMethod): class AscendFusedMoEMethod(FusedMoEMethodBase): """FusedMoE method for Ascend quantization. - This class calls AscendQuantizer to search a specific quantization - implementations supported on ascend hardware for kvcache methods. - Args: quant_config: The Ascend quantization config. """ def __init__(self, quant_config: AscendQuantConfig, prefix: str, packed_modules_mapping: Dict[str, Any]): - self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description, prefix, packed_modules_mapping) - self.quant_method = self.quantizer.build_moe_method() + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "moe", + packed_modules_mapping) def create_weights( self, @@ -341,17 +414,20 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) + def get_fused_moe_quant_config(self, layer: torch.nn.Module): + # TODO: implement this function + pass + class AscendEmbeddingMethod(AscendLinearMethod): """Embedding method for Ascend quantization. - This class calls AscendQuantizer to search a specific quantization - implementations supported on ascend hardware for Embedding methods. 
+ Args: quant_config: The Ascend quantization config. """ def __init__(self, quant_config: AscendQuantConfig, prefix: str, packed_modules_mapping: Dict[str, Any]) -> None: - self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description, prefix, packed_modules_mapping) - self.quant_method = self.quantizer.build_linear_method() + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "linear", + packed_modules_mapping) diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py deleted file mode 100644 index 0e15ed2..0000000 --- a/vllm_ascend/quantization/quantizer.py +++ /dev/null @@ -1,311 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import importlib -import sys -import types -from typing import Any, Dict, List, Optional - -from vllm.logger import logger - -from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init, - wrapper_vocab_parallel_embedding_init) -from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, - AscendW4A8DynamicLinearMethod) -from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, - AscendW8A8LinearMethod) -from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, - AscendW8A8DynamicLinearMethod) - -CUSTOMIZED_QUANTIZER_TYPE: List[str] = [] - - -class AscendQuantizer: - """An interface to different quantization implementations for ascend hardwares.""" - - @classmethod - def get_quantizer(cls, - quant_config: Dict[str, Any], - prefix: str, - packed_modules_mapping: Optional[Dict[str, - Any]] = dict()): - # TODO: Need a param to choose quantization algorithms. - quantization_algorithm = '' - - if quantization_algorithm in CUSTOMIZED_QUANTIZER_TYPE: - return - - return VLLMAscendQuantizer.get_quantizer(quant_config, prefix, - packed_modules_mapping) - - def build_linear_method(self): - raise NotImplementedError - - def build_moe_method(self): - raise NotImplementedError - - def build_attention_method(self): - raise NotImplementedError - - -class VLLMAscendQuantizer: - _instance: Optional[object] = None - patched = False - - def __init__(self, quant_description): - if VLLMAscendQuantizer.patched: - return - for name in quant_description.keys(): - if "norm.bias" in name: - VLLMAscendQuantizer.apply_patch( - "vllm.model_executor.layers.layernorm.RMSNorm", "__init__", - [wrapper_rmsnorm_init]) - VLLMAscendQuantizer.apply_patch( - "vllm_ascend.ops.layernorm.AscendRMSNorm", "forward_oot", - [wrapper_rmsnorm_forward_oot]) - VLLMAscendQuantizer.apply_patch( - "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding", - "__init__", [wrapper_vocab_parallel_embedding_init]) - break - VLLMAscendQuantizer.patched = True - 
logger.info("Using the vLLM Ascend Quantizer version now!") - - @staticmethod - def apply_patch(target_module, target_function, wrappers): - - original_module, original_function = VLLMAscendQuantizer.parse_path( - target_module, target_function, False) - - original_function_id = id(original_function) - - candidate = original_function - for wrapper in wrappers: - candidate = wrapper(candidate) - if target_function is not None: - setattr(original_module, target_function, candidate) - - for _, value in sys.modules.copy().items(): - if target_function is None: - continue - try: - attr = getattr(value, target_function, None) - if attr is not None and id(attr) == original_function_id: - setattr(value, target_function, candidate) - except ImportError: - continue - - @staticmethod - def parse_path(module_path, function_name, create_dummy): - """ - Parse module path and resolve/create modules as needed. - - Args: - module_path: Dot-separated module path - function_name: Target function name (None for module only) - create_dummy: Create dummy modules/functions when missing - - Returns: - Tuple of (resolved module, target function/none) - - Raises: - ModuleNotFoundError: If module path is invalid and create_dummy=False - AttributeError: If function is missing and create_dummy=False - """ - from importlib.machinery import ModuleSpec - - def create_dummy_module(full_path, parent=None): - """Create and register a placeholder module""" - dummy = types.ModuleType(full_path) - dummy.__file__ = "vllm_ascend.dummy_module.py" - dummy.__spec__ = ModuleSpec(full_path, None) - sys.modules[full_path] = dummy - if parent: - setattr(parent, full_path.split(".")[-1], dummy) - return dummy - - def create_placeholder_function(func_name): - """Create dummy function that raises when called""" - - def placeholder(*args, **kwargs): - raise NotImplementedError( - f"Function {func_name} is a placeholder") - - placeholder.__name__ = func_name - return placeholder - - modules = module_path.split(".") 
- current_module = None - processed_path = [] - - for idx, part in enumerate(modules): - current_path = ".".join(modules[:idx + 1]) - parent_path = ".".join(modules[:idx]) if idx > 0 else None - - try: - current_module = importlib.import_module(current_path) - except ModuleNotFoundError: - # Handle missing module - parent = importlib.import_module( - parent_path) if parent_path else None - if parent and hasattr(parent, part): - # Use existing attribute from parent - current_module = getattr(parent, part) - # Check for early function resolution - if function_name and hasattr(current_module, - function_name): - return current_module, getattr(current_module, - function_name) - if function_name and create_dummy: - ph_func = create_placeholder_function(function_name) - setattr(current_module, function_name, ph_func) - return current_module, ph_func - if function_name: - raise AttributeError( - f"Function {function_name} missing in {current_path}" - ) - else: - if not create_dummy: - raise - # Create and register dummy module - current_module = create_dummy_module( - current_path, - parent=importlib.import_module(parent_path) - if parent_path else None) - - processed_path.append(part) - - # Final function handling - final_module = sys.modules[module_path] - if function_name is not None: - if not hasattr(final_module, function_name): - if create_dummy: - ph_func = create_placeholder_function(function_name) - setattr(final_module, function_name, ph_func) - else: - setattr(final_module, function_name, None) - return final_module, getattr(final_module, function_name) - - return final_module, None - - @staticmethod - def build_linear_method(): - raise NotImplementedError( - "Linear method is not implemented for the current quant type.") - - @staticmethod - def build_moe_method(): - raise NotImplementedError( - "MoE method is not implemented for the current quant type.") - - @staticmethod - def build_attention_method(): - raise NotImplementedError( - "Attention method is not 
implemented for the current quant type.") - - @staticmethod - def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, - packed_modules_mapping: Dict[str, Any]): - proj_name = prefix.split(".")[-1] - if proj_name in packed_modules_mapping: - quant_type = None - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in packed_modules_mapping[proj_name] - ] - for shard_prefix in shard_prefixes: - shard_quant_type = quant_description[shard_prefix + '.weight'] - - if quant_type is None: - quant_type = shard_quant_type - elif shard_quant_type != quant_type: - raise ValueError( - f"Not all shards of {prefix} are quantized with same quant type." - f"Shard {proj_name} uses {shard_quant_type}, but another shard" - f"use {quant_type}. Please check quantization config.") - else: - quant_type = quant_description[prefix + '.weight'] - return quant_type - - @classmethod - def get_quantizer(cls, - quant_description: Dict[str, Any], - prefix: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None): - if packed_modules_mapping is None: - packed_modules_mapping = dict() - # Attention - if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): - quant_type = quant_description['fa_quant_type'] - # Use KVCache int8 - elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys(): - quant_type = quant_description['kv_quant_type'] - # Linear - else: - quant_type = cls.get_linear_quant_type(quant_description, prefix, - packed_modules_mapping) - if quant_type in SUPPORT_ASCEND_QUANTIZER_TYPE.keys(): - cls = SUPPORT_ASCEND_QUANTIZER_TYPE[quant_type] - if not cls._instance: - cls._instance = cls(quant_description) - return cls._instance - raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \ - f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}") - - -class W4A8DYNAMICQuantizer(VLLMAscendQuantizer): - - @staticmethod - def build_linear_method(): - return AscendW4A8DynamicLinearMethod() 
- - @staticmethod - def build_moe_method(): - return AscendW4A8DynamicFusedMoEMethod() - - -class W8A8Quantizer(VLLMAscendQuantizer): - - @staticmethod - def build_linear_method(): - return AscendW8A8LinearMethod() - - @staticmethod - def build_moe_method(): - return AscendW8A8FusedMoEMethod() - - @staticmethod - def build_attention_method(): - return AscendC8KVCacheMethod() - - -class W8A8DYNAMICQuantizer(VLLMAscendQuantizer): - - @staticmethod - def build_linear_method(): - return AscendW8A8DynamicLinearMethod() - - @staticmethod - def build_moe_method(): - return AscendW8A8DynamicFusedMoEMethod() - - -SUPPORT_ASCEND_QUANTIZER_TYPE = { - "W4A8_DYNAMIC": W4A8DYNAMICQuantizer, - "W8A8": W8A8Quantizer, - "W8A8_DYNAMIC": W8A8DYNAMICQuantizer, - "C8": W8A8Quantizer, -} diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py new file mode 100644 index 0000000..dc5845a --- /dev/null +++ b/vllm_ascend/quantization/utils.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, Optional, Type + +from vllm.logger import logger + +from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, + AscendW4A8DynamicLinearMethod) +from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, + AscendW8A8LinearMethod) +from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, + AscendW8A8DynamicLinearMethod) + +ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { + "W4A8_DYNAMIC": { + "linear": AscendW4A8DynamicLinearMethod, + "moe": AscendW4A8DynamicFusedMoEMethod, + }, + "W8A8": { + "linear": AscendW8A8LinearMethod, + "moe": AscendW8A8FusedMoEMethod, + "attention": AscendC8KVCacheMethod, + }, + "W8A8_DYNAMIC": { + "linear": AscendW8A8DynamicLinearMethod, + "moe": AscendW8A8DynamicFusedMoEMethod, + }, + "C8": { + "attention": AscendC8KVCacheMethod, + }, +} + + +def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, + packed_modules_mapping: Dict[str, Any]): + proj_name = prefix.split(".")[-1] + if proj_name in 
packed_modules_mapping: + quant_type = None + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in packed_modules_mapping[proj_name] + ] + for shard_prefix in shard_prefixes: + shard_quant_type = quant_description[shard_prefix + '.weight'] + + if quant_type is None: + quant_type = shard_quant_type + elif shard_quant_type != quant_type: + raise ValueError( + f"Not all shards of {prefix} are quantized with same quant type." + f"Shard {proj_name} uses {shard_quant_type}, but another shard" + f"use {quant_type}. Please check quantization config.") + else: + quant_type = quant_description[prefix + '.weight'] + return quant_type + + +def get_quant_method(quant_description: Dict[str, Any], + prefix: str, + layer_type: str, + packed_modules_mapping: Optional[Dict[str, Any]] = None): + logger.info_once("Using the vLLM Ascend Quantization now!") + if packed_modules_mapping is None: + packed_modules_mapping = dict() + # Attention + if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): + quant_type = quant_description['fa_quant_type'] + # Use KVCache int8 + elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys(): + quant_type = quant_description['kv_quant_type'] + # Linear + else: + quant_type = get_linear_quant_type(quant_description, prefix, + packed_modules_mapping) + if quant_type in ASCEND_QUANTIZATION_METHOD_MAP.keys(): + method_map = ASCEND_QUANTIZATION_METHOD_MAP[quant_type] + if layer_type in method_map.keys(): + method_cls = method_map[layer_type] + return method_cls() + else: + raise NotImplementedError( + f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}." 
+ ) + raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \ + f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}") diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index 72f956d..b8bcc78 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -24,10 +24,10 @@ from vllm.config import get_current_vllm_config from vllm.distributed import get_ep_group from vllm.forward_context import get_forward_context -from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.fused_moe import unified_fused_experts_eager -from vllm_ascend.ops.layers.experts_selector import select_experts +from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ class AscendW4A8DynamicLinearMethod: @@ -133,11 +133,14 @@ class AscendW4A8DynamicFusedMoEMethod: vllm_config = get_current_vllm_config() self.group_size = vllm_config.quant_config.quant_description.get( "group_size", 256) + # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process + self.is_per_channel_weight = self.group_size == 0 quant_version = vllm_config.quant_config.quant_description.get( "version", "0") # NOTE: new quantize weights: 2 int4 pack into int8 self.new_quant_version = quant_version == "1.0.0" self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size + self.dynamic_eplb = get_ascend_config().dynamic_eplb if self.new_quant_version and self.tp_size > 16: raise ValueError( "The current weight does not support moe part tp>16.") @@ -182,44 +185,44 @@ class AscendW4A8DynamicFusedMoEMethod: num_experts, 2 * intermediate_size_per_partition, 1, - dtype=params_dtype) + dtype=torch.float32) param_dict["w13_weight_offset"] = 
torch.empty( num_experts, 2 * intermediate_size_per_partition, 1, - dtype=params_dtype) - - param_dict["w13_weight_scale_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=params_dtype) - - param_dict["w13_weight_offset_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=params_dtype) + dtype=torch.float32) param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, - dtype=params_dtype) + dtype=torch.float32) param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, - dtype=params_dtype) - param_dict["w2_weight_scale_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=params_dtype) - param_dict["w2_weight_offset_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=params_dtype) + dtype=torch.float32) + if not self.is_per_channel_weight: + param_dict["w13_weight_scale_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + param_dict["w13_weight_offset_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + + param_dict["w2_weight_scale_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) + param_dict["w2_weight_offset_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) if self.new_quant_version: param_dict["w13_scale_bias"] = torch.empty( @@ -275,14 +278,6 @@ class AscendW4A8DynamicFusedMoEMethod: e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts) - fused_moe_state = get_forward_context().fused_moe_state - 
shared_gate_up, shared_dequant_scale = None, None - if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] - # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. @@ -291,27 +286,36 @@ class AscendW4A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(x.dtype) - return unified_fused_experts_eager( + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_second, - w2_scale=layer.w2_weight_scale_second, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, w1_scale_bias=layer.w13_scale_bias, w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, topk_ids=topk_ids, row_idx=row_idx, + use_int4_w4a8=True, expert_map=expert_map, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, shared_experts=shared_experts, - shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale, - mc2_mask=kwargs.get("mc2_mask", None), - with_quant=True) + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, + dynamic_eplb=self.dynamic_eplb) def process_scale(self, weight: torch.Tensor, scale, per_group_scale): + scale = scale.transpose(1, 2).contiguous() + if self.is_per_channel_weight: + scale_np = scale.cpu().numpy() + scale_np.dtype = np.uint32 + scale_uint64_tensor = torch.from_numpy(scale_np.astype( + np.int64)).npu() + return scale_uint64_tensor, None + per_group_scale = per_group_scale.transpose(1, 2).contiguous() group_num, k, n = weight.shape # the weight of the new version is reduced by half by pack n, so it needs to be restored if 
self.new_quant_version: @@ -354,13 +358,10 @@ class AscendW4A8DynamicFusedMoEMethod: def pack_to_int32(self, weight: torch.Tensor): if self.new_quant_version: - group_num, k, n = weight.shape - assert n % 4 == 0, "the last dim of weight needs to be divided by 4" - packed_n = n // 4 # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 - packed_weight = torch.from_numpy( - np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32)) - return packed_weight.reshape(group_num, k, packed_n).npu() + assert weight.shape[ + -1] % 4 == 0, "the last dim of weight needs to be divided by 4" + return weight.view(torch.int32).contiguous() else: return torch_npu.npu_quantize(weight.to(torch.float32), torch.tensor([1.]).npu(), None, @@ -372,23 +373,29 @@ class AscendW4A8DynamicFusedMoEMethod: 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() - layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( - 1, 2).contiguous() - layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( - 1, 2).contiguous() - layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose( - 1, 2).contiguous() - layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose( - 1, 2).contiguous() - layer.w13_weight_scale_second.data, w13_bias = self.process_scale( + w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr( + layer, "w13_weight_scale_second") else None + w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr( + layer, "w2_weight_scale_second") else None + layer.w13_weight_scale.data, w13_bias = self.process_scale( layer.w13_weight, layer.w13_weight_scale.data, - layer.w13_weight_scale_second.data) - layer.w2_weight_scale_second.data, w2_bias = self.process_scale( + w13_weight_scale_second) + layer.w2_weight_scale.data, w2_bias = self.process_scale( layer.w2_weight, layer.w2_weight_scale.data, - 
layer.w2_weight_scale_second.data) + w2_weight_scale_second) + if hasattr(layer, "w13_weight_scale_second"): + # scale_second is no longer used, release this part of the memory + del layer.w13_weight_scale_second + del layer.w2_weight_scale_second + del layer.w13_weight_offset_second + del layer.w2_weight_offset_second self.update_bias(layer, w13_bias, w2_bias) + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index e4cbdc8..010d45d 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.ops.layers.experts_selector import select_experts +from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 20c68be..ab4987f 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -23,181 +23,10 @@ from vllm.config import CompilationLevel, get_current_vllm_config from vllm.distributed import get_ep_group from vllm.forward_context import get_forward_context -import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.common_fused_moe import \ - fused_experts as unified_fused_experts -from 
vllm_ascend.ops.fused_moe import unified_fused_experts_eager -from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor - - -def apply_mlp_decode(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - group_list: torch.Tensor, - dynamic_scale: torch.Tensor = None, - group_list_type: int = 1) -> torch.Tensor: - """ - apply MLP: gate_up_proj -> swiglu -> down_proj - Args: - hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). - w1: expert weights1 with shape - (num_experts, hidden_size, intermediate_size * 2) - w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) - w2: expert weights2 with shape - (num_experts, intermediate_size, hidden_size) - w2_scale: weights2 scale with shape (num_experts, hidden_size) - group_list: number of tokens for each expert, follow cumsum mode, and - with shape (num_experts). - transpose_weight: - w1: (num_experts, intermediate_size * 2, hidden_size) -> - (num_experts, hidden_size, intermediate_size * 2) - w2: (num_experts, hidden_size, intermediate_size) -> - (num_experts, intermediate_size, hidden_size) - Returns: - hidden_states: output hidden states after MLP. - """ - - if dynamic_scale is None: - unquantized_hidden_states = hidden_states - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( - hidden_states) - # Dispose the original unquantized hidden states - # to save npu memory because they're no longer used. 
- dispose_tensor(unquantized_hidden_states) - else: - pertoken_scale = dynamic_scale - - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=3, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=torch.int32)[0] - - # act_fn: swiglu - hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( - x=hidden_states, - weight_scale=w1_scale, - activation_scale=pertoken_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=group_list, - activate_left=True, - quant_mode=1, - ) - - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - scale=[w2_scale], - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=w2_scale.dtype)[0] - return hidden_states - - -def apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - group_list: torch.Tensor, - dynamic_scale: torch.Tensor = None, - group_list_type: int = 1, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None) -> torch.Tensor: - """ - apply MLP: gate_up_proj -> swiglu -> down_proj - - Args: - hidden_states: input hidden states with shape (num_tokens, hidden_size). - w1: expert weights1 with shape - (num_experts, hidden_size, intermediate_size * 2) - w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) - w2: expert weights2 with shape - (num_experts, intermediate_size, hidden_size) - w2_scale: weights2 scale with shape (num_experts, hidden_size) - group_list: number of tokens for each expert, follow cumsum mode, and - with shape (num_experts). 
- transpose_weight: - w1: (num_experts, intermediate_size * 2, hidden_size) -> - (num_experts, hidden_size, intermediate_size * 2) - w2: (num_experts, hidden_size, intermediate_size) -> - (num_experts, intermediate_size, hidden_size) - - Returns: - hidden_states: output hidden states after MLP. - """ - - if dynamic_scale is None: - unquantized_hidden_states = hidden_states - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( - hidden_states) - # Dispose the original unquantized hidden states - # to save npu memory because they're no longer used. - dispose_tensor(unquantized_hidden_states) - else: - pertoken_scale = dynamic_scale - - bias1, bias2 = None, None - _output_dtype = w2_scale.dtype - - if w1_scale_bias is not None: - if group_list_type == 0: - group_list = torch.cat( - [group_list[:1], torch.diff(group_list, dim=0)]) - group_list_type = 1 - bias1 = [w1_scale_bias] - bias2 = [w2_scale_bias] - # TODO w4a8 scene: dynamic acquisition of dtype in the future - _output_dtype = torch.bfloat16 - - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - scale=[w1_scale], - bias=bias1, - per_token_scale=[pertoken_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=_output_dtype)[0] - - # act_fn: swiglu - hidden_states = torch_npu.npu_swiglu(hidden_states) - hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( - hidden_states) - - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - scale=[w2_scale], - bias=bias2, - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=_output_dtype)[0] - - return hidden_states +from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ class AscendW8A8DynamicLinearMethod: @@ -271,8 +100,9 @@ class 
AscendW8A8DynamicLinearMethod: def process_weights_after_loading(self, layer): if self.transpose_weight: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - # cast quantized weight tensors in NZ format (29) for higher inference speed - layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29) + # cast quantized weight tensors in NZ format for higher inference speed + layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, + ACL_FORMAT_FRACTAL_NZ) layer.weight_scale.data = layer.weight_scale.data.flatten() layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) layer.weight_offset.data = layer.weight_offset.data.flatten() @@ -293,6 +123,7 @@ class AscendW8A8DynamicFusedMoEMethod: vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager and not ascend_config.torchair_graph_config.enabled) + self.dynamic_eplb = ascend_config.dynamic_eplb try: device_group = get_mc2_group().device_group @@ -387,25 +218,19 @@ class AscendW8A8DynamicFusedMoEMethod: global_num_experts=global_num_experts) if self.use_aclgraph: - return unified_fused_experts( + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=row_idx, use_int8_w8a8=True, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, expert_map=expert_map, - ) - - fused_moe_state = get_forward_context().fused_moe_state - shared_gate_up, shared_dequant_scale = None, None - if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] + dynamic_eplb=self.dynamic_eplb) # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a 
single rank. @@ -415,23 +240,24 @@ class AscendW8A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(x.dtype) - return unified_fused_experts_eager( + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, + w1_scale=layer.w13_weight_scale_fp32, w2=layer.w2_weight, w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, topk_ids=topk_ids, row_idx=row_idx, + use_int8_w8a8=True, expert_map=expert_map, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, shared_experts=shared_experts, - shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale, - mc2_mask=kwargs.get("mc2_mask", None), - with_quant=True) + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, + dynamic_eplb=self.dynamic_eplb) def process_weights_after_loading(self, layer): if self.transpose_weight: @@ -439,8 +265,8 @@ class AscendW8A8DynamicFusedMoEMethod: 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() - if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP: - torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ) + torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ) + torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( layer.w13_weight_scale.data.shape[0], -1) layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to( diff --git a/vllm_ascend/sample/logits_processor/__init__.py b/vllm_ascend/sample/logits_processor/__init__.py new file mode 100644 index 0000000..5f810bf --- /dev/null +++ b/vllm_ascend/sample/logits_processor/__init__.py @@ -0,0 +1,50 @@ +import itertools +from collections.abc import Sequence +from typing import TYPE_CHECKING, Union + +import torch +from vllm.logger import init_logger +from vllm.v1.sample import logits_processor +from 
vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, + MinTokensLogitsProcessor) +from vllm.v1.sample.logits_processor.interface import LogitsProcessor +from vllm.v1.sample.logits_processor.state import LogitsProcessors + +from vllm_ascend.sample.logits_processor.builtin import \ + AscendMinPLogitsProcessor + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + +# Error message when the user tries to initialize vLLM with a pooling model +# and custom logitsproces +STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom" + " logits processors.") + +BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ + MinTokensLogitsProcessor, + LogitBiasLogitsProcessor, + AscendMinPLogitsProcessor, +] + + +def build_logitsprocs( + vllm_config: "VllmConfig", + device: torch.device, + is_pin_memory: bool, + is_pooling_model: bool, + custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (), +) -> LogitsProcessors: + if is_pooling_model: + if custom_logitsprocs: + raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) + logger.debug("Skipping logits processor loading because pooling models" + " do not support logits processors.") + return LogitsProcessors() + custom_logitsprocs_classes = logits_processor._load_custom_logitsprocs( + custom_logitsprocs) + return LogitsProcessors( + ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( + BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) diff --git a/vllm_ascend/sample/logits_processor/builtin.py b/vllm_ascend/sample/logits_processor/builtin.py new file mode 100644 index 0000000..f38d940 --- /dev/null +++ b/vllm_ascend/sample/logits_processor/builtin.py @@ -0,0 +1,35 @@ +import torch +from vllm.config import VllmConfig +from vllm.v1.sample.logits_processor import MinPLogitsProcessor + + +class AscendMinPLogitsProcessor(MinPLogitsProcessor): + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: 
bool): + super().__init__(vllm_config, device, is_pin_memory) + + decode_max_num_seqs = getattr(vllm_config.scheduler_config, + 'decode_max_num_seqs', 0) + if decode_max_num_seqs != 0: + max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs, + decode_max_num_seqs) + + self.min_p_count: int = 0 + + self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=is_pin_memory) + self.min_p_cpu = self.min_p_cpu_tensor.numpy() + + self.use_double_tensor = torch.device(device).type != "cpu" + + if self.use_double_tensor: + # Pre-allocated device tensor + self.min_p_device: torch.Tensor = torch.empty( + (max_num_reqs, ), dtype=torch.float32, device=device) + else: + self.min_p_device = self.min_p_cpu_tensor + # Current slice of the device tensor + self.min_p: torch.Tensor = self.min_p_device[:0] diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index b5a212a..9cceda6 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -5,11 +5,10 @@ from vllm.v1.sample.sampler import Sampler from vllm_ascend.utils import is_310p, vllm_version_is -if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")): +if vllm_version_is("0.10.2"): from vllm.config import LogprobsMode DEFAULT_LOGPROBS_MODE = LogprobsMode.RAW_LOGPROBS else: - LogprobsMode = None DEFAULT_LOGPROBS_MODE = "raw_logprobs" @@ -30,7 +29,8 @@ class AscendTopKTopPSampler(TopKTopPSampler): p: torch.Tensor, ) -> torch.Tensor: # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P - if not is_310p() and p is not None and k is not None: + if not is_310p() and p is not None and k is not None and 1 <= int( + k.max()) <= 1024: # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) return torch_npu.npu_top_k_top_p(logits, p, k) @@ -68,19 +68,19 @@ class AscendTopKTopPSampler(TopKTopPSampler): def forward_native(self, logits, generators, k, p): 
"""Override pytorch native implementation to torch_npu""" logits = self._apply_top_k_top_p(logits, k, p) - if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")): - - logits_to_return = None + logits_to_return = None + if vllm_version_is("0.10.2"): if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: logits_to_return = logits elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) + else: + if self.logprobs_mode == "processed_logits": + logits_to_return = logits + elif self.logprobs_mode == "processed_logprobs": + logits_to_return = logits.log_softmax(dim=-1, + dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32) - output = None - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - output = random_sample(probs, generators) - else: - output = (random_sample(probs, generators), logits_to_return) - return output + return random_sample(probs, generators), logits_to_return diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py new file mode 100644 index 0000000..64076c2 --- /dev/null +++ b/vllm_ascend/spec_decode/__init__.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py +# +from vllm_ascend.spec_decode.eagle_proposer import EagleProposer +from vllm_ascend.spec_decode.mtp_proposer import MtpProposer +from vllm_ascend.spec_decode.ngram_proposer import NgramProposer + + +def get_spec_decode_method(method, vllm_config, device, runner): + if method == "ngram": + return NgramProposer(vllm_config, device, runner) + elif method in ["eagle", "eagle3"]: + return EagleProposer(vllm_config, device, runner) + elif method == 'deepseek_mtp': + return MtpProposer(vllm_config, device, runner) + else: + raise ValueError("Unknown speculative decoding method: " + f"{method}") diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py new file mode 100644 index 0000000..d14dc6d --- /dev/null +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -0,0 +1,674 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +from vllm.attention.layer import Attention +from vllm.config import (CompilationLevel, VllmConfig, + get_layers_from_vllm_config) +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import logger +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_multimodal +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType +from vllm_ascend.utils import 
vllm_version_is + +PADDING_SLOT_ID = -1 + + +class EagleProposer(Proposer): + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device, + runner=None): + self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3 + self.vllm_config = vllm_config + self.device = device + self.runner = runner + + self.block_size = vllm_config.cache_config.block_size + # We need to get the hidden size from the draft model config because + # the draft model's hidden size can be different from the target model's + # hidden size (e.g., Llama 3.3 70B). + self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size( + ) + + self.use_cuda_graph = (self.vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE and + not self.vllm_config.model_config.enforce_eager) + self.cudagraph_batch_sizes = list( + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) + + # persistent buffers for cuda graph + self.input_ids = torch.zeros( + self.vllm_config.scheduler_config.max_num_batched_tokens, + dtype=torch.int32, + device=device) + self.positions = torch.zeros( + self.vllm_config.scheduler_config.max_num_batched_tokens, + dtype=torch.int64, + device=device) + self.hidden_states = torch.zeros( + (self.vllm_config.scheduler_config.max_num_batched_tokens, + self.hidden_size), + dtype=self.vllm_config.model_config.dtype, + device=device) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. 
+ self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + + 1, + device=device, + dtype=torch.int32) + attn_mask_len = self.vllm_config.model_config.max_model_len + self.attn_mask_builder = AttentionMaskBuilder( + attn_mask_len, self.vllm_config.model_config.dtype) + + def load_model(self, model: nn.Module) -> None: + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + self.model = get_model(vllm_config=self.vllm_config, + model_config=self.vllm_config. + speculative_config.draft_model_config) + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + self.attn_layer_name = next(iter(draft_attn_layer_names)) + + # share embed_tokens with the target model if needed + if get_pp_group().world_size == 1: + logger.info( + "The EAGLE head shares the same vocab embedding" \ + " with the target model." + ) + self.model.model.embed_tokens = model.model.embed_tokens + else: + logger.info( + "Since PP > 1, the EAGLE head loaded its own vocab embedding" \ + " weights instead of sharing them with the target model." 
+ ) + + # share lm_head with the target model if needed + # some model definition do not define lm_head explicitly + # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM + if self.name == SpecDcodeType.EAGLE and hasattr(model, "lm_head"): + logger.info("Loading EAGLE LM head weights from the target model.") + if supports_multimodal(model): + self.model.lm_head = model.get_language_model().lm_head + else: + self.model.lm_head = model.lm_head + + @torch.inference_mode() + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + skip_attn: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: Optional[torch.Tensor] = None): + moe_comm_type = self.runner._select_moe_comm_method( + num_tokens, with_prefill) + with set_ascend_forward_context(None, + self.vllm_config, + moe_comm_type=moe_comm_type, + num_tokens=num_tokens): + self.model( + input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + ) + + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + attn_metadata=None, + aux_hidden_states: torch.Tensor = None): + + attn_metadata = self._get_eagle_atten_dict(scheduler_output) + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. 
+ req_id = self.runner.input_batch.req_ids[i] + req_state = self.runner.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + eagle_attn_metadata = attn_metadata[self.attn_layer_name] + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.runner.input_ids[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + if self.name == SpecDcodeType.EAGLE3: + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], + dim=-1) + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] + target_slot_mapping = eagle_attn_metadata.slot_mapping + cu_num_tokens = eagle_attn_metadata.query_start_loc + else: + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) + cu_num_tokens, token_indices = self._prepare_inputs( + eagle_attn_metadata.query_start_loc, num_rejected_tokens, + num_tokens) + target_token_ids = self.runner.input_ids[token_indices] + target_positions = positions[token_indices] + if self.name == SpecDcodeType.EAGLE3: + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1) + else: + target_hidden_states = hidden_states[token_indices] + target_slot_mapping = eagle_attn_metadata.slot_mapping[ + token_indices] + + draft_token_ids = self._propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + 
target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=eagle_attn_metadata.block_tables, + sampling_metadata=sampling_metadata, + ) + spec_token_ids = draft_token_ids.tolist() + return spec_token_ids + + def _get_eagle_atten_dict( + self, + scheduler_output: "SchedulerOutput", + ): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.runner.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.runner.input_batch.block_table.commit_block_table(num_reqs) + + # Get the number of scheduled tokens for each request. + req_ids = self.runner.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = max(tokens) + self.runner.query_lens = torch.from_numpy(num_scheduled_tokens) + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.runner.arange_np[:num_reqs], + num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) + + # Get positions. + positions_np = self.runner.positions_np[:total_num_scheduled_tokens] + np.add(self.runner.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.runner.uses_mrope: + self.runner._calc_mrope_positions(scheduler_output) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. 
+ token_indices = ( + positions_np + + req_indices * self.runner.input_batch.token_ids_cpu.shape[1]) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select( + self.runner.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.runner.input_ids_cpu[:total_num_scheduled_tokens]) + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + # NOTE(Chen): there is exactly one KV cache group that contains all + # attention layers in the model for now, so the current logic for + # getting attn_metadata is not related to kv_cache_group information. + # Will extend this part to support multiple KV cache groups later. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.runner.kv_cache_config.kv_cache_groups): + block_size = kv_cache_group_spec.kv_cache_spec.block_size + block_table = self.runner.input_batch.block_table[ + kv_cache_group_id] + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions_np // block_size) + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() + block_offsets = positions_np % block_size + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + + # Prepare the attention metadata.
+ self.runner.query_start_loc_np[0] = 0 + self.runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + + self.runner.seq_lens_np[:num_reqs] = ( + self.runner.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + + # Copy the tensors to the NPU. + self.runner.input_ids[:total_num_scheduled_tokens].copy_( + self.runner.input_ids_cpu[:total_num_scheduled_tokens], + non_blocking=True) + if self.runner.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.runner.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.runner. + mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + else: + # Common case (1D positions) + self.runner.positions[:total_num_scheduled_tokens].copy_( + self.runner.positions_cpu[:total_num_scheduled_tokens], + non_blocking=True) + + self.runner.query_start_loc[:num_reqs + 1].copy_( + self.runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) + self.runner.seq_lens[:num_reqs].copy_( + self.runner.seq_lens_cpu[:num_reqs], non_blocking=True) + + # Fill unused with -1. Needed for reshape_and_cache + self.runner.seq_lens[num_reqs:].fill_(0) + self.runner.query_start_loc[num_reqs + 1:].fill_(-1) + + attn_metadata = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.runner.kv_cache_config.kv_cache_groups): + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.runner.query_start_loc_cpu[:num_reqs + + 1], + seq_lens_cpu=self.runner.seq_lens_cpu, + num_reqs=num_reqs, + max_query_len=max_num_scheduled_tokens, + num_actual_tokens=total_num_scheduled_tokens, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. 
+ get_device_tensor(), + slot_mapping=self.runner.slot_mapping, + positions=self.runner.positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + num_computed_tokens_cpu=None, + seq_lens=None) + if vllm_version_is("0.10.2"): + builder = self.runner.attn_groups[0][0].metadata_builder + else: + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata_i = builder.build(0, common_attn_metadata, + self.runner.get_model()) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + return attn_metadata + + def _get_cumsum_and_arange( + self, + num_tokens: np.ndarray, + cumsum_dtype: Optional[np.dtype] = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) + # Step 3. 
[0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.runner.arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + + def _propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] starting with 0 + cu_num_tokens: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + device = cu_num_tokens.device + cu_num_tokens = cu_num_tokens.cpu() + block_table = block_table.cpu() + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + last_token_indices = cu_num_tokens[1:] - 1 + target_positions = target_positions.cpu() + if self.name == SpecDcodeType.EAGLE3: + assert isinstance(self.model, Eagle3LlamaForCausalLM) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) + assert target_hidden_states.shape[-1] == self.hidden_size + + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. 
+ # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids + seq_lens = (target_positions[last_token_indices] + 1).int() + + query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] + max_query_len = query_lens.max().item() + attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask( + seq_lens, target_positions, self.vllm_config.model_config.dtype, + self.device) + + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=cu_num_tokens.to(device), + query_start_loc_cpu=cu_num_tokens, + seq_lens_cpu=seq_lens.cpu(), + max_query_len=max_query_len, + num_reqs=batch_size, + num_actual_tokens=num_tokens, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), + slot_mapping=target_slot_mapping, + positions=target_positions, + attn_mask=attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + num_computed_tokens_cpu=None, + seq_lens=None) + # FIXME(woosuk): The below two ops cause synchronization. Optimize. 
+ if vllm_version_is("0.10.2"): + builder = self.runner.attn_groups[0][0].metadata_builder + else: + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata = builder.build(0, common_attn_metadata, + self.runner.get_model()) + if self.use_cuda_graph and \ + num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + else: + num_input_tokens = num_tokens + + with_prefill = attn_metadata.attn_state not in [ + AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + ] + moe_comm_type = self.runner._select_moe_comm_method( + num_input_tokens, with_prefill) + + # copy inputs to buffer for cudagraph + self.positions[:num_tokens] = target_positions.to(device) + self.hidden_states[:num_tokens] = target_hidden_states + attn_metadata.block_tables = block_table.to(device) + with set_ascend_forward_context(attn_metadata, + self.vllm_config, + moe_comm_type=moe_comm_type, + num_tokens=num_input_tokens): + last_hidden_states, hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + ) + sample_hidden_states = last_hidden_states[last_token_indices] + if vllm_version_is("0.10.2"): + logits = self.model.compute_logits(sample_hidden_states, None) + else: + logits = self.model.compute_logits(sample_hidden_states) + draft_token_ids = logits.argmax(dim=-1) + + # Early exit if there is only one draft token to be generated. + if self.vllm_config.speculative_config.num_speculative_tokens == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + # Generate the remaining draft tokens. 
+ draft_token_ids_tensor = torch.zeros( + (self.vllm_config.speculative_config.num_speculative_tokens, + *draft_token_ids.shape), + dtype=draft_token_ids.dtype) + draft_token_ids_tensor[0] = draft_token_ids + + positions_cpu = target_positions[last_token_indices].cpu().to( + torch.int64) + hidden_states = hidden_states[last_token_indices] + if self.use_cuda_graph and \ + batch_size <= self.cudagraph_batch_sizes[-1]: + input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) + else: + input_batch_size = batch_size + + moe_comm_type = self.runner._select_moe_comm_method( + input_batch_size, False) + + attn_metadata.num_actual_tokens = batch_size + attn_metadata.max_query_len = 1 + attn_metadata.query_start_loc = self.arange[:batch_size + 1] + query_lens.fill_(1) + attn_metadata.query_lens = query_lens + + attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill + for now_speculative in range( + self.vllm_config.speculative_config.num_speculative_tokens - + 1): + # Update the inputs. + # cast to int32 is crucial when eagle model is compiled. + # tensor.argmax() returns int64 by default. + input_ids = draft_token_ids_tensor[now_speculative].to(device) + positions_cpu += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions_cpu >= self.vllm_config.model_config.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions_cpu = torch.where(exceeds_max_model_len, 0, + positions_cpu) + clamped_positions = clamped_positions_cpu.to(device) + + # TODO: Increment the sequence lengths. 
+ + attn_metadata.seq_lens += 1 + # TODO: Consider max model length. + # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, + # self.max_model_len) + # TODO: For the requests that exceed the max model length, set the + # sequence length to 1 to minimize their overheads in attention. + + # Compute the slot mapping. + block_numbers = (clamped_positions_cpu // self.block_size) + block_ids = block_table.gather(dim=1, + index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + slot_mapping_cpu = ( + block_ids * self.vllm_config.cache_config.block_size + + clamped_positions_cpu % self.block_size) + + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + slot_mapping_cpu.masked_fill_(exceeds_max_model_len, + PADDING_SLOT_ID) + # NOTE: ASCEND slot_mapping must be on cpu + attn_metadata.slot_mapping = slot_mapping_cpu.to( + torch.int32).to(device) + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:batch_size] = hidden_states + attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask( + attn_metadata.seq_lens, positions_cpu, + self.vllm_config.model_config.dtype, self.device) + + attn_metadata.attn_mask = attn_mask + attn_metadata.block_tables = block_table.to(device) + # Run the model.
+ with set_ascend_forward_context(attn_metadata, + self.vllm_config, + moe_comm_type=moe_comm_type, + num_tokens=input_batch_size): + + last_hidden_states, hidden_states = self.model( + input_ids=self.input_ids[:input_batch_size], + positions=self.positions[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], + ) + hidden_states = hidden_states[:batch_size] + if vllm_version_is("0.10.2"): + logits = self.model.compute_logits( + last_hidden_states[:batch_size], None) + else: + logits = self.model.compute_logits( + last_hidden_states[:batch_size]) + + # TODO(wenlong): get more than one token for tree attention + draft_token_ids = logits.argmax(dim=-1) + draft_token_ids_tensor[now_speculative + 1] = draft_token_ids.cpu() + + # [batch_size, num_speculative_tokens] + draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) + return draft_token_ids + + def _prepare_inputs( + self, + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + num_tokens: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + + # [a - n1, b - n2, c - n3] -> + # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + cu_num_tokens = torch.zeros_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_target_query_lens.device, + ) + BLOCK_SIZE = 1024 + 
self._prepare_eagle_input_sequential( + token_indices, + cu_target_query_lens, + cu_num_tokens, + block_size=BLOCK_SIZE, + ) + return cu_num_tokens, token_indices + + def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor, + cu_query_lens: torch.Tensor, + cu_num_tokens: torch.Tensor, + block_size: int): + num_programs = len(cu_num_tokens) - 1 + for pid in range(num_programs): + start_pos = cu_num_tokens[pid].item() + end_pos = cu_num_tokens[pid + 1].item() + num_tokens = end_pos - start_pos + index_start = cu_query_lens[pid].item() + num_blocks = int( + torch.ceil(torch.tensor(num_tokens / block_size)).item()) + + for i in range(num_blocks): + offset_tensor = torch.arange(0, + block_size, + dtype=torch.int32, + device=out_tensor.device) + global_start_offset = i * block_size + target_indices = torch.tensor( + start_pos + global_start_offset, + dtype=torch.int32, + device=out_tensor.device) + offset_tensor + values_to_store = torch.tensor( + index_start + global_start_offset, + dtype=torch.int32, + device=out_tensor.device) + offset_tensor + mask = (target_indices >= start_pos) & \ + (target_indices < end_pos) & \ + (offset_tensor < num_tokens) + out_tensor[target_indices[mask]] = values_to_store[mask] diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py new file mode 100644 index 0000000..0efe93d --- /dev/null +++ b/vllm_ascend/spec_decode/interface.py @@ -0,0 +1,51 @@ +import enum +from typing import Optional + +import torch +from vllm.config import VllmConfig +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + + +class SpecDcodeType(enum.Enum): + NGRAM = 0 + EAGLE = 1 + EAGLE3 = 2 + MTP = 4 + + +class Proposer: + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device = None, + runner=None): + pass + + def load_model(self, model): + """Called by load_model in model_runner""" + raise 
NotImplementedError + + @torch.inference_mode() + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + skip_attn: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: Optional[torch.Tensor] = None): + """Called by dummy_run in model_runner""" + raise NotImplementedError + + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + attn_metadata=None, + aux_hidden_states: torch.Tensor = None): + """Called by execute_model in model_runner""" + raise NotImplementedError diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py new file mode 100644 index 0000000..d0a0d50 --- /dev/null +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -0,0 +1,657 @@ +import types + +import torch +import torch.nn as nn +import torchair +from torchair import patch_for_hcom +from vllm.attention.layer import Attention +from vllm.config import (VllmConfig, get_layers_from_vllm_config, + set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading, set_default_torch_dtype) +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP +from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType +from
vllm_ascend.torchair.models.torchair_deepseek_mtp import \ + TorchairDeepSeekMTP +from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR, + TorchairCommonAttentionMetadata) +from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, + vllm_version_is) + +PADDING_SLOT_ID = -1 + + +class MtpProposer(Proposer): + + def __init__( + self, + vllm_config: VllmConfig, + device, + runner, + ): + self.name = SpecDcodeType.MTP + self.vllm_config = vllm_config + self.device = device + self.runner = runner + self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens + + # persistent buffers for graph + self.input_ids = torch.zeros(self.runner.max_num_tokens, + dtype=torch.int32, + device=self.device) + self.positions = torch.zeros(self.runner.max_num_tokens, + dtype=torch.int64, + device=self.device) + self.hidden_states = torch.zeros( + (self.runner.max_num_tokens, + vllm_config.model_config.get_hidden_size()), + dtype=self.runner.dtype, + device=self.device) + self.torchair_compiled_model = None # type: ignore + self.torchair_compiled_models = {} # type: ignore + self.torchair_graph_enabled = get_ascend_config( + ).torchair_graph_config.enabled + self.enable_shared_expert_dp = get_ascend_config( + ).enable_shared_expert_dp + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. 
+ self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + + 1, + device=self.runner.device, + dtype=torch.int32) + + def load_model(self, model) -> None: + loader = get_model_loader(self.vllm_config.load_config) + + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + draft_model_config = \ + self.vllm_config.speculative_config.draft_model_config + target_device = self.vllm_config.device_config.device + + with set_default_torch_dtype( + draft_model_config.dtype), set_current_vllm_config( + self.vllm_config): + if self.torchair_graph_enabled or ( + self.enable_shared_expert_dp + and self.vllm_config.model_config.use_mla): + self.model = TorchairDeepSeekMTP( + vllm_config=self.vllm_config).to(target_device) + else: + self.model = CustomDeepSeekMTP( + vllm_config=self.vllm_config).to(target_device) + + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + + assert len(draft_attn_layer_names) == 1 + self.attn_layer_name = list(draft_attn_layer_names) + + self.model.load_weights( + loader.get_all_weights( + self.vllm_config.speculative_config.draft_model_config, + self.model)) + process_weights_after_loading(self.model, draft_model_config, + target_device) + + @torch.inference_mode() + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + skip_attn: bool = False, + num_reqs: int = 0, + num_tokens_across_dp=None) -> None: + if not self.torchair_graph_enabled: + # TODO: adapt enable_dbo later + (num_tokens, num_tokens_across_dp, with_prefill, + _) = self.runner._sync_metadata_across_dp(num_tokens, + with_prefill, False) + + moe_comm_type = self.runner._select_moe_comm_method( + num_tokens, with_prefill) + + is_running_torchair = self.torchair_graph_enabled and \ + not with_prefill + + if is_running_torchair: + skip_attn = False + if skip_attn: + attn_metadata = None + else: + common_attn_metadata = 
TorchairCommonAttentionMetadata( + num_reqs=num_reqs, + num_actual_tokens=1, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + decode_token_per_req=self.runner.decode_token_per_req, + ) + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( + common_attn_metadata) + + input_ids = self.input_ids[:num_tokens] + positions = self.positions[:num_tokens] + previous_hidden_states = self.hidden_states[:num_tokens] + for _ in range(self.num_speculative_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + with_prefill=with_prefill, + num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + in_profile_run=self.runner.in_profile_run, + num_actual_tokens=0): + if is_running_torchair: + assert attn_metadata is not None + torch._dynamo.mark_static(input_ids) + torch._dynamo.mark_static(positions) + torch._dynamo.mark_static(previous_hidden_states) + torch._dynamo.mark_static(attn_metadata.decode.block_table) + torch._dynamo.mark_static( + attn_metadata.decode.input_positions) + if hasattr(attn_metadata.decode, "sin"): + torch._dynamo.mark_static(attn_metadata.decode.sin) + torch._dynamo.mark_static(attn_metadata.decode.cos) + torch._dynamo.mark_static(get_forward_context().mc2_mask) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + torch._dynamo.mark_static(attn_metadata.decode.attn_mask) + torchair_compiled_model = self._get_torchair_lazy_compiled_model( + num_tokens) + torchair_compiled_model( + input_ids=input_ids, + positions=positions, + previous_hidden_states=previous_hidden_states, + inputs_embeds=None, + intermediate_tensors=None, + attn_metadata=attn_metadata, + kv_caches=self.runner.kv_caches[-1:], + spec_step_idx=0) + else: + self.model(input_ids=input_ids, + positions=positions, + 
previous_hidden_states=previous_hidden_states) + if with_prefill: + break + + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + attn_metadata=None, + aux_hidden_states: torch.Tensor = None): + if attn_metadata is not None and isinstance(attn_metadata, dict): + attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = self.runner.input_batch.req_ids[i] + req_state = self.runner.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + accepted_token_indices = None + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.runner.input_ids[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + target_hidden_states = hidden_states[:num_scheduled_tokens] + target_slot_mapping = attn_metadata.slot_mapping + cu_num_tokens = attn_metadata.query_start_loc + else: + # TODO(woosuk): Refactor this. 
+ num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + cu_num_tokens, accepted_token_indices, target_token_ids, \ + target_positions, target_hidden_states, target_slot_mapping = self._prepare_inputs( + attn_metadata.query_start_loc, + num_rejected_tokens, + self.runner.input_ids[:num_scheduled_tokens], + positions[:num_scheduled_tokens], + hidden_states[:num_scheduled_tokens], + attn_metadata.slot_mapping[:num_scheduled_tokens], + is_torchair_graph=self.runner._build_drafter_prepare_inputs_torchair_param(), + ) + + draft_token_ids = self._propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=attn_metadata.block_tables, + sampling_metadata=sampling_metadata, + token_indices=accepted_token_indices) + spec_token_ids = draft_token_ids.tolist() + return spec_token_ids + + def _prepare_inputs( + self, + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + token_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + slot_mapping: torch.Tensor, + is_torchair_graph: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + # [0, a, a + b, a + b + c] -> [a, b, c] + 
query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + if is_torchair_graph: + cu_num_tokens = cu_target_query_lens + relative_index = query_len_per_req - num_rejected_tokens - 1 + token_indices = cu_num_tokens[:-1] + relative_index + # the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model + target_token_ids = token_ids + target_positions = positions + target_hidden_states = hidden_states + target_slot_mapping = slot_mapping + else: + cu_num_tokens = torch.empty_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + cu_num_tokens[0] = 0 + + # FIXME(woosuk): Avoid synchronization. + num_tokens = cu_num_tokens[-1].item() + token_indices = torch.zeros( + num_tokens, + dtype=torch.int32, + device=cu_num_tokens.device, + ) + + BLOCK_SIZE = 1024 + self._prepare_input_kernel( + token_indices, + cu_target_query_lens, + cu_num_tokens, + block_size=BLOCK_SIZE, + ) + target_token_ids = token_ids[token_indices] + target_positions = positions[token_indices] + target_hidden_states = hidden_states[token_indices] + target_slot_mapping = slot_mapping[token_indices] + return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping + + def _propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] starting with 0 + cu_num_tokens: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + token_indices=None) -> torch.Tensor: + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + 
last_token_indices = cu_num_tokens[1:] - 1 + + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + if token_indices is not None and self.torchair_graph_enabled: + last_token_indices = token_indices + + self.input_ids[last_token_indices] = next_token_ids + + query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] + max_query_len = query_lens.max().item() + + # FIXME: reorder_batch() needs to be called before build() + # because fields of attn_metadata_builder needs to be updated. + # However, currently reorder_batch() takes input_batch and + # scheduler_output as arguments, we should probably refactor + # the method to use new data structures which are independent + # from input_batch and scheduler_output. + # self.runner.attn_metadata_builder.reorder_batch( + # input_batch=self.runner.input_batch, + # scheduler_output=self.runner.scheduler_output, + # ) + is_running_torchair = self.torchair_graph_enabled and \ + not self.runner.with_prefill + + if is_running_torchair: + # Torchair graph mode, padding is same as the main model + num_input_tokens = self.runner.graph_pad_size + elif (self.runner.use_aclgraph + and num_tokens <= self.runner.aclgraph_batch_sizes[-1]): + # Acl graph mode, add padding to the batch size + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + else: + # Eager mode, no padding needed + num_input_tokens = num_tokens + + seq_lens = target_positions[last_token_indices] + 1 + seq_lens = seq_lens.int() + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=cu_num_tokens[:batch_size + 1], + query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(), + seq_lens_cpu=seq_lens.cpu(), + num_reqs=batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + 
actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), + slot_mapping=target_slot_mapping, + positions=target_positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + graph_pad_size=self.runner.graph_pad_size, + decode_token_per_req=self.runner.decode_token_per_req, + num_computed_tokens_cpu=None, + seq_lens=None) + + if not self.torchair_graph_enabled: + if vllm_version_is("0.10.2"): + builder = self.runner.attn_groups[0][0].metadata_builder + else: + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata_mtp = builder.build(0, common_attn_metadata, + self.runner.get_model()) + + attn_metadata = {} + for layer_name in self.attn_layer_name: + attn_metadata[layer_name] = attn_metadata_mtp + + else: + attn_metadata = self.runner.attn_metadata_builder.build( + 0, common_attn_metadata, self.runner.get_model()) + + self.positions[:num_tokens] = target_positions + self.hidden_states[:num_tokens] = target_hidden_states + + if not self.torchair_graph_enabled: + # torch mode need to update num_tokens_across_dp + # TODO: adapt enable_dbo later + (num_input_tokens, num_tokens_across_dp, with_prefill, + _) = self.runner._sync_metadata_across_dp( + num_input_tokens, self.runner.with_prefill, False) + else: + # torchair mode can reuse self.runner.num_tokens_across_dp + num_tokens_across_dp = self.runner.num_tokens_across_dp + with_prefill = self.runner.with_prefill + + moe_comm_type = self.runner._select_moe_comm_method( + num_input_tokens, with_prefill) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=False) + aclgraph_runtime_mode, batch_descriptor = \ + self.runner.aclgraph_dispatcher.dispatch(batch_descriptor) + + for step in range(self.num_speculative_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + 
with_prefill=with_prefill, + num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + aclgraph_runtime_mode=aclgraph_runtime_mode, + in_profile_run=self.runner.in_profile_run, + num_actual_tokens=num_tokens): + with ProfileExecuteDuration().capture_async('mtp_forward'): + model_kwargs = {} + model_kwargs["attn_metadata"] = attn_metadata + if self.torchair_graph_enabled: + model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] + if is_running_torchair: + torchair_compiled_model = self._get_torchair_lazy_compiled_model( + num_input_tokens) + hidden_states = torchair_compiled_model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + previous_hidden_states=self. + hidden_states[:num_input_tokens], + inputs_embeds=None, + intermediate_tensors=None, + spec_step_idx=0, + **model_kwargs) + else: + hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + previous_hidden_states=self. 
+ hidden_states[:num_input_tokens], + kv_caches=self.runner.kv_caches[-1:]) + + num_indices = last_token_indices.shape[0] + if lmhead_tp_enable(): + if not self.runner.with_prefill: + max_num_reqs_across_dp = num_input_tokens + else: + max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs + last_token_indices = nn.functional.pad( + last_token_indices, + (0, max_num_reqs_across_dp - num_indices)) + + sample_hidden_states = hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if lmhead_tp_enable() and num_indices < logits.shape[0]: + logits = logits[:num_indices] + draft_token_ids = logits.argmax(dim=-1) + + if self.num_speculative_tokens == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + if step == 0: + draft_token_ids_list = [draft_token_ids] + else: + draft_token_ids_list.append(draft_token_ids) + + # prepare next mtp inputs + # mtp>1: prefill skip or decode skip last loop + if with_prefill and self.torchair_graph_enabled: + for _ in range(self.num_speculative_tokens - 1): + draft_token_ids_list.append(draft_token_ids) + if step == self.num_speculative_tokens - 1 or with_prefill: + break + + if not self.torchair_graph_enabled: + attn_metadata_i = attn_metadata[self.attn_layer_name[0]] + else: + attn_metadata_i = attn_metadata + + if step == 0: + positions = target_positions[last_token_indices] + hidden_states = hidden_states[last_token_indices] + slot_mapping = attn_metadata_i.slot_mapping[last_token_indices] + attn_metadata_i.slot_mapping.fill_(-1) + attn_metadata_i.query_start_loc = self.arange[:batch_size + 1] + last_token_indices = self.arange[:batch_size] + if attn_metadata_i.num_decode_tokens != 0: + attn_metadata_i.num_decode_tokens = batch_size + if is_running_torchair: + attn_metadata_i.num_actual_tokens = batch_size + attn_metadata_i.query_lens = [1] * batch_size + + input_ids = draft_token_ids_list[-1].int() + positions += 1 + + # NOTE(woosuk): We should handle the case 
where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions >= self.runner.model_config.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) + # Increment the sequence lengths. + attn_metadata_i.seq_lens[:batch_size] += 1 + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. + exceeds_max_model_len_cpu = exceeds_max_model_len.to( + attn_metadata_i.seq_lens.device, non_blocking=True) + attn_metadata_i.seq_lens[:batch_size].masked_fill_( + exceeds_max_model_len_cpu, 1) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. 
+ slot_mapping += 1 + slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:hidden_states.shape[0]] = hidden_states + attn_metadata_i.slot_mapping[:batch_size] = slot_mapping + + if attn_metadata_i.prefill is not None: + attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens + attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens + attn_metadata_i.prefill.input_positions = self.positions[: + num_input_tokens] + attn_metadata_i.prefill.max_seq_lens += 1 + attn_metadata_i.prefill.max_seq_lens = min( + attn_metadata_i.prefill.max_seq_lens, + self.runner.model_config.max_model_len) + if attn_metadata_i.decode is not None: + attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens + attn_metadata_i.decode.input_positions = self.positions[: + num_input_tokens] + attn_metadata_i.decode.max_seq_lens += 1 + attn_metadata_i.decode.max_seq_lens = min( + attn_metadata_i.decode.max_seq_lens, + self.runner.model_config.max_model_len) + + # mtp>1: [batch_size, k] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + return draft_token_ids + + def _get_torchair_lazy_compiled_model(self, batch_size: int): + if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[ + -1]: + raise ValueError( + f"Bad graph batch size:{batch_size}! 
max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}" + ) + + compiled_model = self.torchair_compiled_models.get( + batch_size + ) if self.runner.use_cached_npu_graph else self.torchair_compiled_model + + if compiled_model: + return compiled_model + + patch_for_hcom() + config = torchair.CompilerConfig() + config.experimental_config.frozen_parameter = True + config.experimental_config.tiling_schedule_optimize = True + config.experimental_config.enable_view_optimize = \ + get_ascend_config().torchair_graph_config.enable_view_optimize + torch.npu.set_compile_mode(jit_compile=False) + if not self.runner.use_cached_npu_graph: + npu_backend = torchair.get_npu_backend(compiler_config=config) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=not get_ascend_config().use_sfa, + fullgraph=True, + backend=npu_backend) + return self.torchair_compiled_model + else: + # Generate a new forward proxy code object to prevent the invalidation of + # compilation cache caused by dynamo retracing + forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" + forward_fn = self.model.forward + code = forward_fn.__code__ + # Mark code object with a new proxy name + modified_code = code.replace(co_name=forward_proxy_name, ) + + modified_func = types.FunctionType(modified_code, + forward_fn.__globals__, + name=forward_proxy_name, + argdefs=forward_fn.__defaults__) + + self.model.__dict__[forward_proxy_name] = modified_func.__get__( + self.model, nn.Module) + self.torchair_compiled_models[ + batch_size] = torchair.inference.cache_compile( + self.model.__dict__[forward_proxy_name], + dynamic=not get_ascend_config().use_sfa, + fullgraph=True, + cache_dir=TORCHAIR_CACHE_DIR, + config=config, + ge_cache=False) + return self.torchair_compiled_models[batch_size] + + # TODO Using torch instead of triton may result in poor performance + def _prepare_input_kernel(self, out_ptr: torch.Tensor, + cu_query_lens: torch.Tensor, + 
cu_num_tokens: torch.Tensor, block_size: int): + device = cu_query_lens.device + dtype = out_ptr.dtype + + offsets = torch.arange(block_size, device=device, dtype=dtype) + start_pos = cu_num_tokens[:-1] + end_pos = cu_num_tokens[1:] + num_tokens = end_pos - start_pos + + global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1)) + values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)) + + mask = (offsets.view(1, -1) < num_tokens.view(-1, 1)) + + global_indices_flat = global_indices[mask] + values_flat = values[mask] + out_ptr[global_indices_flat] = values_flat diff --git a/vllm_ascend/spec_decode/ngram_proposer.py b/vllm_ascend/spec_decode/ngram_proposer.py new file mode 100644 index 0000000..9999f1f --- /dev/null +++ b/vllm_ascend/spec_decode/ngram_proposer.py @@ -0,0 +1,65 @@ +import torch +from vllm.v1.spec_decode.ngram_proposer import \ + NgramProposer as VllmNgramProposer + +from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType + + +class NgramProposer(VllmNgramProposer, Proposer): + + def __init__(self, vllm_config, device, runner): + super().__init__(vllm_config) + self.name = SpecDcodeType.NGRAM + self.device = device + self.runner = runner + + def load_model(self, *args, **kwargs): + # No model to load. + pass + + @torch.inference_mode() + def dummy_run(self, + num_tokens, + with_prefill=None, + skip_attn=None, + num_reqs=None, + num_tokens_across_dp=None): + pass + + def generate_token_ids(self, + valid_sampled_token_ids, + sampling_metadata=None, + scheduler_output=None, + spec_decode_metadata=None, + positions=None, + num_scheduled_tokens=None, + hidden_states=None, + attn_metadata=None, + aux_hidden_states=None) -> list[list[int]]: + # TODO(woosuk): Optimize. + draft_token_ids: list[list[int]] = [] + for i, sampled_ids in enumerate(valid_sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. 
+ draft_token_ids.append([]) + continue + + # Skip requests that require top-p, top-k, etc. + req_id = self.runner.input_batch.req_ids[i] + if req_id in self.runner.input_batch.spec_decode_unsupported_reqs: + draft_token_ids.append([]) + continue + + # Add sampled_token_ids to token_ids_cpu. + start_idx = self.runner.input_batch.num_tokens_no_spec[i] + end_idx = start_idx + num_sampled_ids + self.runner.input_batch.token_ids_cpu[ + i, start_idx:end_idx] = sampled_ids + drafter_output = self.propose( + self.runner.input_batch.token_ids_cpu[i, :end_idx]) + if drafter_output is None or len(drafter_output) == 0: + draft_token_ids.append([]) + else: + draft_token_ids.append(drafter_output.tolist()) + return draft_token_ids diff --git a/vllm_ascend/torchair/models/qwen2.py b/vllm_ascend/torchair/models/qwen2.py index 3537aa8..6e4990d 100644 --- a/vllm_ascend/torchair/models/qwen2.py +++ b/vllm_ascend/torchair/models/qwen2.py @@ -40,7 +40,6 @@ from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM # noqa: F401 from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm_ascend.ascend_config import get_ascend_config @@ -343,9 +342,9 @@ class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): return hidden_states def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + hidden_states: torch.Tensor, + sampling_metadata=None, # type: ignore ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) diff --git a/vllm_ascend/torchair/models/qwen3_moe.py b/vllm_ascend/torchair/models/qwen3_moe.py index dd4a592..c6aad6a 100644 --- a/vllm_ascend/torchair/models/qwen3_moe.py +++ b/vllm_ascend/torchair/models/qwen3_moe.py @@ -54,8 +54,9 @@ from 
vllm.sequence import IntermediateTensors from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.ops.sequence_parallel import (MetadataForPadding, - init_metadata_for_sp) +from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding, + init_metadata_for_sp) +from vllm_ascend.utils import vllm_version_is class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): @@ -311,9 +312,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): quant_config=quant_config, prefix=f"{prefix}.mlp") else: - self.mlp = Qwen3MoeSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if vllm_version_is("0.10.2"): + self.mlp = Qwen3MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, + prefix=f"{prefix}.mlp") else: self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, @@ -394,7 +400,8 @@ class CustomQwen3MoeModel(Qwen3MoeModel): quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config - self.num_redundant_experts = parallel_config.num_redundant_experts + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.config = config diff --git a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py index 6cb98a5..a7c5a6e 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py @@ -27,14 +27,12 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from 
vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.deepseek_mtp import ( DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, SharedHead) from vllm.model_executor.models.utils import maybe_prefix -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm_ascend.torchair.models.torchair_deepseek_v2 import \ @@ -172,7 +170,7 @@ class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, + sampling_metadata=None, # type: ignore spec_step_idx: int = 0, ) -> torch.Tensor: current_step_idx = (spec_step_idx % self.num_mtp_layers) @@ -199,8 +197,6 @@ class TorchairDeepSeekMTP(DeepSeekMTP): self.model = TorchairDeepSeekMultiTokenPredictor( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self.sampler = get_sampler() - def forward( self, input_ids: torch.Tensor, diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index b31549d..8cf6e24 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -32,8 +32,7 @@ import torch_npu from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import (CacheConfig, ModelConfig, VllmConfig, - get_current_vllm_config) +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, split_tensor_along_last_dim, @@ -52,7 +51,6 @@ from vllm.model_executor.layers.linear import 
(ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -69,12 +67,14 @@ from vllm.model_executor.models.utils import ( make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.sequence import IntermediateTensors +from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.models.layers.sfa import Indexer from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ TorchairAscendW8A8DynamicLinearMethod -from vllm_ascend.utils import dispose_tensor, npu_prefetch +from vllm_ascend.utils import dispose_tensor, npu_prefetch, oproj_tp_enable class TorchairDeepseekV2SiluAndMul(SiluAndMul): @@ -322,8 +322,8 @@ class TorchairDeepseekV2MoE(nn.Module): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_moe = \ - ascend_config.torchair_graph_config.enable_multistream_moe and \ + self.multistream_overlap_shared_expert = \ + ascend_config.multistream_overlap_shared_expert and \ self.torchair_graph_enabled self.gate = ReplicatedLinear(config.hidden_size, @@ -364,7 +364,7 @@ class TorchairDeepseekV2MoE(nn.Module): hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=reduce_results, - force_replicate=self.enable_multistream_moe + force_replicate=self.multistream_overlap_shared_expert or enable_shared_expert_dp, prefix=f"{prefix}.shared_experts", ) @@ -377,10 +377,6 
@@ class TorchairDeepseekV2MoE(nn.Module): self.tp_group = get_tp_group().device_group self.tp_rank = get_tp_group().rank_in_group self.ep_group = get_ep_group() - self.kv_consumer = None - transfer_config = get_current_vllm_config().kv_transfer_config - if transfer_config is not None: - self.kv_consumer = transfer_config.kv_role == "kv_consumer" self.params_dtype = torch.get_default_dtype() self.rm_router_logits = self.experts.rm_router_logits @@ -398,15 +394,9 @@ class TorchairDeepseekV2MoE(nn.Module): is_prefill = forward_context.with_prefill - # If this node is kv_consumer, we force the moe always runs in decode path to make sure - # the behaviour aligned between dummy_run and normal model_execute. - if self.kv_consumer: - is_prefill = False - enable_force_load_balance = False - # router_logits: (num_tokens, n_experts) router_logits = None - if not self.rm_router_logits and not self.enable_multistream_moe: + if not self.rm_router_logits and not self.multistream_overlap_shared_expert: router_logits, _ = self.gate(hidden_states) experts_hidden_states = self.experts( @@ -447,6 +437,7 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention): cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + decoder_layer=None, ) -> None: nn.Module.__init__(self) self.hidden_size = hidden_size @@ -514,11 +505,18 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention): bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_b_proj") - if (config.n_routed_experts is not None - and self.debug_layer_idx >= config.first_k_dense_replace - and self.debug_layer_idx % config.moe_layer_freq == 0 - and (ascend_config.torchair_graph_config.enable_multistream_moe - or self.enable_shared_expert_dp)): + + if oproj_tp_enable(): + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + elif (config.n_routed_experts 
is not None + and self.debug_layer_idx >= config.first_k_dense_replace + and self.debug_layer_idx % config.moe_layer_freq == 0 + and (ascend_config.multistream_overlap_shared_expert + or self.enable_shared_expert_dp)): self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce( self.num_heads * self.v_head_dim, self.hidden_size, @@ -635,6 +633,225 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention): output_shape=output_shape) +class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + decoder_layer=None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + self.tp_size = get_tensor_model_parallel_world_size() + assert num_heads % self.tp_size == 0 + self.num_local_heads = num_heads // self.tp_size + self.layers = config.num_hidden_layers + self.first_k_dense_replace = config.first_k_dense_replace + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + if self.q_lora_rank 
is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + return_bias=False, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + return_bias=False, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + return_bias=False, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + return_bias=False, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + return_bias=False, + ) + if (config.n_routed_experts is not None + and self.debug_layer_idx >= config.first_k_dense_replace + and self.debug_layer_idx % config.moe_layer_freq == 0 + and (ascend_config.multistream_overlap_shared_expert + or self.enable_shared_expert_dp)): + self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + return_bias=False, + ) + else: + self.o_proj = TorchairDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + return_bias=False, + ) + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + 
rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.dim: int = config.hidden_size # 7168 + # TODO(zzzzwwjj): wait transformers add these params + self.n_heads: int = 64 # 64 + self.head_dim: int = 128 # 128 + self.index_topk: int = 2048 # 2048 + self.indexer = Indexer( + config, + quant_config=quant_config, + dim=self.dim, + n_heads=self.n_heads, + head_dim=self.head_dim, + index_topk=self.index_topk, + prefix=f"{prefix}.indexer", + ) + + self.sfa_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + use_sfa=True, + # SFA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + indexer=self.indexer, + decoder_layer=decoder_layer, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() + if not self.torchair_graph_enabled: + if 
forward_context.attn_metadata is not None and isinstance( + forward_context.attn_metadata, dict): + attn_metadata = next( + iter(forward_context.attn_metadata.values()), None) + else: + attn_metadata = forward_context.attn_metadata + if kv_cache is None: + kv_cache = self.sfa_attn.kv_cache[ + forward_context.virtual_engine] + + num_tokens = hidden_states.shape[0] + need_gather_q_kv = False + # if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + # # Simulate all gather to calculate output shape + # num_tokens = num_tokens * self.tp_size + # need_gather_q_kv = True + if not self.enable_shared_expert_dp or self.debug_layer_idx != self.first_k_dense_replace: + output_shape = hidden_states.shape + if self.enable_shared_expert_dp and ( + self.debug_layer_idx == self.first_k_dense_replace + or self.debug_layer_idx == self.layers): + rows = num_tokens // self.tp_size + if num_tokens % self.tp_size: + rows += 1 + output_shape = (rows, hidden_states.shape[1]) + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata, + need_gather_q_kv, output) + output = output.view(-1, output_shape[-1]) + return output + + class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): def __init__( @@ -659,9 +876,16 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tp_group().rank_in_group ascend_config = get_ascend_config() + self.use_mla = False + self.use_sfa = False # TODO: enable mla in vllm-ascend if model_config.use_mla: - attn_cls = TorchairDeepseekV2MLAAttention + if ascend_config.use_sfa: + attn_cls = TorchairDeepseekV2SFAAttention + self.use_sfa = True + else: + attn_cls = TorchairDeepseekV2MLAAttention # type: ignore[assignment] + self.use_mla = True else: attn_cls = DeepseekV2Attention self.self_attn = attn_cls( 
@@ -680,6 +904,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + decoder_layer=self, ) if (config.n_routed_experts is not None @@ -690,7 +915,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \ + self.mla_moe_communication = ascend_config.multistream_overlap_shared_expert \ and model_config.use_mla and self.tp_size > 1 else: self.mlp = TorchairDeepseekV2MLP( @@ -720,21 +945,34 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): replace_allreduce: bool = False, ) -> torch.Tensor: # Self Attention - if attn_metadata is not None and attn_metadata.num_decodes > 0: - mla_moe_communication = self.mla_moe_communication and replace_allreduce + if attn_metadata is not None: + decoding_condition_met = ( + not attn_metadata.is_prefill if self.use_sfa else + attn_metadata.num_decodes > 0 if self.use_mla else False) + mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce else: mla_moe_communication = False - if residual is None: + + forward_context = get_forward_context() + if (envs.VLLM_ASCEND_ENABLE_MLAPO + and isinstance(self.self_attn, TorchairDeepseekV2SFAAttention) + and attn_metadata is not None + and not forward_context.with_prefill): + if residual is not None: + hidden_states = hidden_states + residual residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) else: - previous_hidden_states, previous_residual = hidden_states, residual - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - # Dispose hidden_states and residual from the previous layer - # to save npu memory because they're no longer used. 
- dispose_tensor(previous_hidden_states) - dispose_tensor(previous_residual) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + previous_hidden_states, previous_residual = hidden_states, residual + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + # Dispose hidden_states and residual from the previous layer + # to save npu memory because they're no longer used. + dispose_tensor(previous_hidden_states) + dispose_tensor(previous_residual) if mla_moe_communication and self.layer_idx > self.first_k_dense_replace: hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0) @@ -806,6 +1044,8 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): residual = get_tp_group().all_gather(residual, 0) attn_metadata = get_forward_context().attn_metadata + if attn_metadata is not None and isinstance(attn_metadata, dict): + attn_metadata = next(iter(attn_metadata.values()), None) if attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens else: @@ -921,6 +1161,8 @@ class TorchairDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config + self.num_dense_layers = self.config.first_k_dense_replace + self.num_moe_layers = self.config.num_hidden_layers - self.num_dense_layers self.quant_config = quant_config self.model = TorchairDeepseekV2Model(vllm_config=vllm_config, prefix=maybe_prefix( @@ -934,7 +1176,6 @@ class TorchairDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm_ascend/torchair/models/torchair_pangu_moe.py b/vllm_ascend/torchair/models/torchair_pangu_moe.py index eb05760..195ffde 100644 --- 
a/vllm_ascend/torchair/models/torchair_pangu_moe.py +++ b/vllm_ascend/torchair/models/torchair_pangu_moe.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -53,9 +52,9 @@ from vllm.model_executor.models.interfaces import SupportsPP from vllm.model_executor.models.utils import ( extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors +from vllm.v1.sample.sampler import Sampler from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p @@ -913,7 +912,7 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() + self.sampler = Sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -935,19 +934,19 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP): return hidden_states def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + hidden_states: torch.Tensor, + sampling_metadata=None, # type: ignore ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits def sample( - 
self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: + self, + logits: Optional[torch.Tensor], + sampling_metadata, # type: ignore + ): next_tokens = self.sampler(logits, sampling_metadata) return next_tokens diff --git a/vllm_ascend/ops/sequence_parallel.py b/vllm_ascend/torchair/ops/sequence_parallel.py similarity index 100% rename from vllm_ascend/ops/sequence_parallel.py rename to vllm_ascend/torchair/ops/sequence_parallel.py diff --git a/vllm_ascend/torchair/ops/shared_weight_layer.py b/vllm_ascend/torchair/ops/shared_weight_layer.py new file mode 100644 index 0000000..6ab29af --- /dev/null +++ b/vllm_ascend/torchair/ops/shared_weight_layer.py @@ -0,0 +1,245 @@ +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +import torch.distributed as dist +from vllm.distributed.parallel_state import GroupCoordinator +from vllm.model_executor.layers.linear import LinearBase + + +def dispose_tensor(x: torch.Tensor): + x.set_(torch.empty([], device=x.device, dtype=x.dtype)) + + +@dataclass +class LayerMetadata: + """Metadata for a layer. + """ + layer: Optional[LinearBase] # The layer object. + post_method: Callable[[ + torch.nn.Module + ], None] # The `process_weights_after_loading` method from the quant method. + weight: torch.Tensor # The weight tensor. + window_idx: int # The index of the window. + + +@dataclass +class SharedWindowMetadata: + """Metadata for a shared window. + """ + weight: torch.Tensor # The weight tensor to be shared by layers. + data_layer_idx: int # The index of the layer this window's weight is equal to. + work: Optional[torch.distributed.Work] # The asynchronous broadcast work. + + +@dataclass +class SeriesMetadata: + """Metadata for a weight shared series. + """ + group: GroupCoordinator + start_layer: int + end_layer: int + num_layers: int + prefetch_step: int + dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. 
All the layers in the series share the same dummy weight tensor. + layers: list[LayerMetadata] + shared_windows: list[ + SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored. + window_offset: int # The index of the window for the next coming layer. + + def is_source(self, layer_idx) -> bool: + return layer_idx % self.group.world_size == self.group.rank_in_group + + def post_process_after_loading(self): + # This method only needs to be called once per series. + if self.shared_windows: + return + for layer_idx in range(self.start_layer, self.end_layer): + layer = self.layers[layer_idx - self.start_layer] + is_source = self.is_source(layer_idx) + # If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight. + if not is_source: + layer.weight.set_(torch.empty_like(self.dummy_weight)) + # Broadcast to get the true weight. + dist.broadcast(layer.weight, + src=self.group.ranks[layer_idx % + self.group.world_size], + group=self.group.device_group) + assert layer.layer is not None + # Call `process_weights_after_loading` from the quant method. + layer.post_method(layer.layer) + step = layer_idx - self.start_layer + if step < self.prefetch_step: + # Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights. + self.shared_windows.append( + SharedWindowMetadata( + weight=layer.weight.clone().detach(), + data_layer_idx=layer_idx, + work=None, + )) + layer.window_idx = step + # When the layer not intended to be stored in this device, link to the corresponding window's tensor. + if not is_source: + layer.weight.set_(self.shared_windows[-1].weight) + else: + # Build one more window for prefetch. The weight is useless, so just keep the shape. 
+ if step == self.prefetch_step: + self.shared_windows.append( + SharedWindowMetadata( + weight=torch.empty_like(layer.weight), + data_layer_idx=-1, + work=None, + )) + # When the layer not intended to be stored in this device, dispose the tensor. + if not is_source: + dispose_tensor(layer.weight) + + dispose_tensor(self.dummy_weight) + + def reach_layer(self, layer_idx: int): + # The index of the layer to be prefetched. + next_layer_idx = (layer_idx + self.prefetch_step + ) % self.num_layers + self.start_layer + next_layer = self.layers[next_layer_idx - self.start_layer] + # The index of the window to store the weight for the coming layer. + next_layer.window_idx = self.window_offset + window = self.shared_windows[next_layer.window_idx] + # When the layer not intended to be stored in this device, link to the corresponding window's tensor. + if not self.is_source(next_layer_idx): + next_layer.weight.set_(window.weight) + # Update `window_offset` by rolling one step. + self.window_offset = (self.window_offset + 1) % (self.prefetch_step + + 1) + assert window.data_layer_idx != next_layer_idx + window.data_layer_idx = next_layer_idx + # Start asynchronous broadcast work. + window.work = dist.broadcast( + next_layer.weight, + src=self.group.ranks[next_layer_idx % self.group.world_size], + group=self.group.device_group, + async_op=True) + + def wait_weight(self, layer_idx: int): + # Find the asynchronous broadcast work and wait for it. + assert self.shared_windows + window = self.shared_windows[self.layers[layer_idx - + self.start_layer].window_idx] + # Make sure the data in the corresponding shared window is for the current layer. + assert window.data_layer_idx == layer_idx + if window.work is not None: + window.work.wait() + window.work = None + + +@dataclass +class LayerExternalMetadata: + """External metadata for a layer. 
+ """ + series: SeriesMetadata + layer_idx: int + + +_series_dict: dict[str, SeriesMetadata] = {} + +_layer_external_dict: dict[int, LayerExternalMetadata] = {} + + +def _create_forward_wrapper(forward: Callable, series: SeriesMetadata, + layer_idx: int) -> Callable: + + def wrapped_forward(*args, **kwargs): + # Wait for the weight. + series.wait_weight(layer_idx) + return forward(*args, **kwargs) + + return wrapped_forward + + +""" +Register linear layers into a shared storage series. + +In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices. + +After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization. + +During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series. + +Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula: +- total_layers = end_layer - start_layer +- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer + +To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series. + +Arguments: + series_name: This name identifies which series this layer belongs to. + group: The group coordinator for handling asynchronous communications. 
It is recommended to create a new group coordinator for each new series. + start_layer: The index of the first layer in the series (inclusive). + end_layer: The index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer). + layer_idx: The index of the current layer. + layer: The linear layer object to register. + prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases. +""" + + +def register_layer_to_shared_weight_series( + series_name: str, + group: GroupCoordinator, + start_layer: int, + end_layer: int, + layer_idx: int, + layer: LinearBase, + prefetch_step: int = 1, +): + global _series_dict + if series_name not in _series_dict: + num_layers = end_layer - start_layer + assert num_layers > 0 + assert prefetch_step >= 0 and prefetch_step <= num_layers - 2 + _series_dict[series_name] = SeriesMetadata( + group=group, + start_layer=start_layer, + end_layer=end_layer, + num_layers=num_layers, + prefetch_step=prefetch_step, + dummy_weight=torch.empty_like(layer.weight), + layers=[ + LayerMetadata( + layer=None, + post_method=lambda layer: None, + weight=torch.empty([]), + window_idx=-1, + ) for _ in range(num_layers) + ], + shared_windows=[], + window_offset=prefetch_step, + ) + series = _series_dict[series_name] + assert layer.quant_method is not None + series.layers[layer_idx - start_layer] = LayerMetadata( + layer=layer, + post_method=layer.quant_method.process_weights_after_loading, + weight=layer.weight, + window_idx=-1, + ) + # Discard the original `process_weights_after_loading` method such that it won't be called by others. + layer.quant_method.process_weights_after_loading = lambda layer: None + # When the layer not intended to be stored in this device, dispose the tensor and skip weight loading. 
+ if not series.is_source(layer_idx): + dispose_tensor(layer.weight) + layer.weight.weight_loader = lambda *args, **kwargs: None + layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx) + global _layer_external_dict + _layer_external_dict[id(layer)] = LayerExternalMetadata( + series=series, + layer_idx=layer_idx, + ) + + +def post_process_after_loading_for_shared_weight_series(layer: LinearBase): + ext = _layer_external_dict[id(layer)] + ext.series.post_process_after_loading() + + +def reach_layer_for_shared_weight_series(layer: LinearBase): + ext = _layer_external_dict[id(layer)] + ext.series.reach_layer(ext.layer_idx) diff --git a/vllm_ascend/distributed/communication_op.py b/vllm_ascend/torchair/ops/torchair_activation.py similarity index 52% rename from vllm_ascend/distributed/communication_op.py rename to vllm_ascend/torchair/ops/torchair_activation.py index 2e475f5..0721ea0 100644 --- a/vllm_ascend/distributed/communication_op.py +++ b/vllm_ascend/torchair/ops/torchair_activation.py @@ -1,25 +1,37 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. 
-# - -import torch -from vllm.distributed.parallel_state import get_dp_group - - -def data_parallel_reduce_scatter(input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: - """Reduce-Scatter the input tensor across data parallel group.""" - return get_dp_group().reduce_scatter(input_, dim) +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch + + +def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor: + """AscendSiluAndMul forward in torchair mode. + + The key difference from the original implementation is the removal of operators + from the torch.ops.vllm class, as these operators only function in non-torchair + modes. Adding them back would cause the graph compilation to fail. 
+ """ + + import torch_npu + + from vllm_ascend.utils import is_310p + + if is_310p(): + out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) + else: + out = torch_npu.npu_swiglu(x) + return out diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index bd2be21..bd25a79 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -40,17 +40,18 @@ from vllm.model_executor.layers.quantization.base_config import \ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState -from vllm_ascend.distributed.communication_op import \ - data_parallel_reduce_scatter from vllm_ascend.distributed.parallel_state import get_mc2_group +from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, + determine_default_log2phy_map) from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.ops.sequence_parallel import MetadataForPadding from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod +from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_all_reduce_merge_state, get_ascend_soc_version, - get_rm_router_logits_state, is_310p) + get_rm_router_logits_state, is_310p, + vllm_version_is) def torchair_fused_experts_with_mc2( @@ -802,6 +803,7 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp try: device_group = get_mc2_group().device_group @@ -883,6 +885,8 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): topk_ids = torch.randint_like(topk_ids, 0, 
global_num_experts) fused_moe_state = get_forward_context().fused_moe_state + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All if fused_moe_state == FusedMoEState.MC2: return torchair_fused_experts_with_mc2( @@ -1013,45 +1017,70 @@ class TorchairAscendFusedMoE(FusedMoE): self.moe_parallel_config.ep_size, is_deepseek_v3_r1) ascend_config = get_ascend_config() - expert_map_path = ascend_config.expert_map_path - if expert_map_path and os.path.exists(expert_map_path): - # moe expert load balance - expert_load_balancer = ExpertLoadBalancer(expert_map_path, - self.global_num_experts) - self.local_num_experts, self.expert_map = \ - expert_load_balancer.get_rank_placement_map( - self.moe_instance_id, - get_ep_group().rank_in_group) - self.log2phy = expert_load_balancer.get_rank_log2phy_map( - self.moe_instance_id, - get_ep_group().rank_in_group) - self.global_redundant_expert_num = \ - expert_load_balancer.get_global_redundant_expert_num() + self.dynamic_eplb = ascend_config.dynamic_eplb + self.expert_map_path = ascend_config.expert_map_path + self.global_redundant_expert_num = ascend_config.init_redundancy_expert + self.global_num_experts = num_experts + self.global_redundant_expert_num + # static eplb initializing with expert_map_path + if self.expert_map_path and os.path.exists( + self.expert_map_path) and os.access(self.expert_map_path, + os.R_OK): + self.expert_load_balancer = ExpertLoadBalancer( + self.expert_map_path, self.global_num_experts) + self.local_num_experts, self.expert_map = ( + self.expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, self.ep_rank)) + self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( + self.moe_instance_id, self.ep_rank).npu() + self.global_redundant_expert_num = ( + self.expert_load_balancer.get_global_redundant_expert_num()) else: - # Create a tensor of size num_experts filled with -1 + # init moe. 
self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, - get_ep_group().rank_in_group, self.global_num_experts) + self.ep_size, self.ep_rank, self.global_num_experts) + # dynamic eplb initializing with not expert_map_path + if self.dynamic_eplb: + self.global_redundant_expert_num = ascend_config.init_redundancy_expert + self.local_num_experts, self.expert_map = determine_default_expert_map( + self.global_num_experts, self.ep_size, self.ep_rank, + self.global_redundant_expert_num) + self.log2phy = determine_default_log2phy_map( + self.global_num_experts, self.ep_size, self.ep_rank, + self.global_redundant_expert_num) + local_num_experts = (torch.sum(self.expert_map != -1) + if self.expert_map is not None else num_experts) + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_moe = \ - ascend_config.torchair_graph_config.enable_multistream_moe and \ + self.multistream_overlap_shared_expert = \ + ascend_config.multistream_overlap_shared_expert and \ self.torchair_graph_enabled self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") - self.moe = FusedMoEConfig.make( - num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - # TODO (bnell): this needs to be fixed for quantized types. 
- in_dtype=params_dtype, - quant_config=quant_config) + if vllm_version_is("0.10.2"): + self.moe = FusedMoEConfig.make( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + # TODO (bnell): this needs to be fixed for quantized types. + in_dtype=params_dtype, + quant_config=quant_config) + else: + self.moe = FusedMoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=params_dtype, + ) if quant_config is None: self.quant_method = TorchairAscendUnquantizedFusedMoEMethod( self.moe) @@ -1066,8 +1095,11 @@ class TorchairAscendFusedMoE(FusedMoE): assert self.quant_method is not None - local_num_experts = torch.sum(self.expert_map != -1) \ - if self.expert_map is not None else num_experts + self.moe_load = None + local_num_experts = (torch.sum(self.expert_map != -1) + if self.expert_map is not None else num_experts) + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) moe_quant_params = { "num_experts": local_num_experts, @@ -1126,23 +1158,25 @@ class TorchairAscendFusedMoE(FusedMoE): forward_context = get_forward_context() fused_moe_state = forward_context.fused_moe_state mc2_mask = forward_context.mc2_mask + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. 
quantized_x_for_share, dynamic_scale_for_share = None, None - from vllm_ascend.quantization.w8a8_dynamic import \ - AscendW8A8DynamicFusedMoEMethod - if self.enable_multistream_moe: + from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ + TorchairAscendW8A8DynamicFusedMoEMethod + if self.multistream_overlap_shared_expert: if not self.rm_router_logits: router_logits, _ = gate(hidden_states) if hasattr(self.quant_method, "quant_method") and \ isinstance(self.quant_method.quant_method, - AscendW8A8DynamicFusedMoEMethod + TorchairAscendW8A8DynamicFusedMoEMethod ) and fused_moe_state == FusedMoEState.MC2: with npu_stream_switch("moe_secondary", 0): quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant( hidden_states) if shared_experts: - if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: + if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2: # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce shared_hidden_states = shared_experts(hidden_states) @@ -1160,31 +1194,33 @@ class TorchairAscendFusedMoE(FusedMoE): if (fused_moe_state not in [ FusedMoEState.AllGather, FusedMoEState.AllGatherEP, FusedMoEState.NaiveMulticast - ] and not replace_allreduce): - if fused_moe_state in {FusedMoEState.MC2}: - padding_size = forward_context.padded_num_tokens - else: - # TODO: Determine if we can remove the padding - padding_size = tp_size - if num_tokens < padding_size and not self.enable_shared_expert_dp: - hidden_states = nn.functional.pad( - hidden_states, (0, 0, 0, padding_size - num_tokens)) - router_logits = nn.functional.pad( - router_logits, (0, 0, 0, padding_size - num_tokens)) + ]): if tp_size > 1: tp_rank = get_tensor_model_parallel_rank() - if not self.enable_shared_expert_dp: - chunk_hidden_states = torch.tensor_split(hidden_states, - tp_size, - dim=0) - 
chunk_router_logits = torch.tensor_split(router_logits, - tp_size, - dim=0) - hidden_states = chunk_hidden_states[tp_rank] - router_logits = chunk_router_logits[tp_rank] - chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0) mc2_mask = chunk_mc2_mask[tp_rank] + if not replace_allreduce: + if fused_moe_state in {FusedMoEState.MC2}: + padding_size = forward_context.padded_num_tokens + else: + # TODO: Determine if we can remove the padding + padding_size = tp_size + if num_tokens < padding_size and not self.enable_shared_expert_dp: + hidden_states = nn.functional.pad( + hidden_states, (0, 0, 0, padding_size - num_tokens)) + router_logits = nn.functional.pad( + router_logits, (0, 0, 0, padding_size - num_tokens)) + if tp_size > 1: + tp_rank = get_tensor_model_parallel_rank() + if not self.enable_shared_expert_dp: + chunk_hidden_states = torch.tensor_split(hidden_states, + tp_size, + dim=0) + chunk_router_logits = torch.tensor_split(router_logits, + tp_size, + dim=0) + hidden_states = chunk_hidden_states[tp_rank] + router_logits = chunk_router_logits[tp_rank] if self.dp_size > 1: if fused_moe_state == FusedMoEState.AllGather: @@ -1206,8 +1242,12 @@ class TorchairAscendFusedMoE(FusedMoE): router_logits = get_dp_group().all_gather(router_logits, 0) elif fused_moe_state == FusedMoEState.NaiveMulticast: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu + if vllm_version_is("0.10.2"): + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + else: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_sp(1) hidden_states = self.naive_multicast(hidden_states, cu_tokens_across_dp_cpu) if self.rm_router_logits: @@ -1236,7 +1276,8 @@ class TorchairAscendFusedMoE(FusedMoE): log2phy=self.log2phy, global_redundant_expert_num=self.global_redundant_expert_num, shared_experts=shared_experts if self.torchair_graph_enabled - and self.enable_multistream_moe and not is_prefill else 
None, + and self.multistream_overlap_shared_expert and not is_prefill else + None, mc2_mask=mc2_mask, quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, @@ -1246,6 +1287,11 @@ class TorchairAscendFusedMoE(FusedMoE): if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states + if self.dynamic_eplb and isinstance( + e_hidden_states, tuple) and len(e_hidden_states) == 3: + self.moe_load += e_hidden_states[2] if e_hidden_states[1] == 0 else \ + torch.cat(e_hidden_states[2][:1], e_hidden_states[2][1:] - e_hidden_states[2][:-1]) + if (fused_moe_state not in [ FusedMoEState.AllGather, FusedMoEState.AllGatherEP, FusedMoEState.NaiveMulticast @@ -1269,8 +1315,8 @@ class TorchairAscendFusedMoE(FusedMoE): final_hidden_states = final_hidden_states[start:end, :] dispose_tensor(e_hidden_states) elif fused_moe_state == FusedMoEState.AllGather: - final_hidden_states = data_parallel_reduce_scatter( - e_hidden_states, dim=0) + final_hidden_states = get_dp_group().reduce_scatter( + e_hidden_states, 0) final_hidden_states = final_hidden_states[:num_tokens] dispose_tensor(e_hidden_states) else: @@ -1290,6 +1336,19 @@ class TorchairAscendFusedMoE(FusedMoE): else: return final_hidden_states + def update_expert_map(self, new_expert_map): + self.expert_map = new_expert_map + + def get_map(self): + return self.expert_map + + def get_log2phy_map(self): + return self.logical_to_physical_map + + def clear_moe_load(self): + if self.moe_load is not None: + self.moe_load.zero_() + # ----------------------------------------- TBO-related -------------------------------------------- def _forward_ms_fused_moe_comp( diff --git a/vllm_ascend/torchair/ops/torchair_layernorm.py b/vllm_ascend/torchair/ops/torchair_layernorm.py new file mode 100644 index 0000000..d90f889 --- /dev/null +++ b/vllm_ascend/torchair/ops/torchair_layernorm.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from typing import Optional, Tuple, Union + +import torch + + +def torchair_rmsnorm_forward_oot( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """AscendRMSNorm forward in torchair mode. + + The key difference from the original implementation is the removal of operators + from the torch.ops.vllm class, as these operators only function in non-torchair + modes. Adding them back would cause the graph compilation to fail. 
+ """ + + import torch_npu + + from vllm_ascend.utils import is_310p + if residual is not None: + if is_310p(): + orig_dtype = residual.dtype + x = x + residual.to(x.dtype) + residual = x.to(orig_dtype) + x, _ = torch_npu.npu_rms_norm(x, self.weight, + self.variance_epsilon) + else: + x, _, residual = torch_npu.npu_add_rms_norm( + x, residual, self.weight, self.variance_epsilon) + return x, residual + + x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) + return x diff --git a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py index 5793288..e64bd6f 100644 --- a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py +++ b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py @@ -62,7 +62,7 @@ def rope_forward_oot( # adopt custom kernel path for rotary_embedding if custom_rotary_embedding_enabled(query, neox_style, self.head_size) and not is_310p(): - query, key = torch.ops._C.rotary_embedding( + query, key = torch.ops._C_ascend.rotary_embedding( positions, query, key, @@ -93,10 +93,7 @@ def native_rope_deepseek_forward(self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - max_seq_len: Optional[int] = None): - if max_seq_len is not None and max_seq_len > self.max_seq_len: - _set_cos_sin_cache(self, max_seq_len, query.device, query.dtype) + offsets: Optional[torch.Tensor] = None): if len(key.shape) == 2: key = key[:, None, :] # Note: we implement the non neox_style method with shuffle the last dim and neox style @@ -211,8 +208,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed, k_embed -def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len +def _set_cos_sin_cache(self, max_seq_len, device, dtype): dim = self.rotary_dim freq_extra = 1.0 / (self.base**( @@ -232,9 +228,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): inv_freq = freq_inter * 
(1 - inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(seq_len * self.scaling_factor, - device=device, - dtype=torch.float32) + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale @@ -365,8 +359,7 @@ def deepseek_rope_init_func( super(DeepseekScalingRotaryEmbedding, self).__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - self.max_seq_len = max_position_embeddings - _set_cos_sin_cache(self, - max_position_embeddings, - dtype=dtype, - device="npu") + + # NOTE: For ascend friendly computing, reorder sin and cos cache + self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor) + _set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu") diff --git a/vllm_ascend/torchair/quantization/torchair_quantizer.py b/vllm_ascend/torchair/quantization/torchair_quantizer.py deleted file mode 100644 index 1d1d584..0000000 --- a/vllm_ascend/torchair/quantization/torchair_quantizer.py +++ /dev/null @@ -1,29 +0,0 @@ -from vllm_ascend.quantization.quantizer import VLLMAscendQuantizer -from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import ( - TorchairAscendW4A8DynamicFusedMoEMethod, - TorchairAscendW4A8DynamicLinearMethod) -from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( - TorchairAscendW8A8DynamicFusedMoEMethod, - TorchairAscendW8A8DynamicLinearMethod) - - -class TorchairW8A8DYNAMICQuantizer(VLLMAscendQuantizer): - - @staticmethod - def build_linear_method(): - return TorchairAscendW8A8DynamicLinearMethod() - - @staticmethod - def build_moe_method(): - return TorchairAscendW8A8DynamicFusedMoEMethod() - - -class TorchairW4A8DYNAMICQuantizer(VLLMAscendQuantizer): - - @staticmethod - def build_linear_method(): - return TorchairAscendW4A8DynamicLinearMethod() - - @staticmethod - def build_moe_method(): - return 
TorchairAscendW4A8DynamicFusedMoEMethod() diff --git a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py index f38e2d8..02deee8 100644 --- a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py @@ -139,6 +139,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: vllm_config = get_current_vllm_config() self.group_size = vllm_config.quant_config.quant_description.get( "group_size", 256) + # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process + self.is_per_channel_weight = self.group_size == 0 quant_version = vllm_config.quant_config.quant_description.get( "version", "0") # NOTE: new quantize weights: 2 int4 pack into int8 @@ -188,44 +190,45 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: num_experts, 2 * intermediate_size_per_partition, 1, - dtype=params_dtype) + dtype=torch.float32) param_dict["w13_weight_offset"] = torch.empty( num_experts, 2 * intermediate_size_per_partition, 1, - dtype=params_dtype) - - param_dict["w13_weight_scale_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=params_dtype) - - param_dict["w13_weight_offset_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=params_dtype) + dtype=torch.float32) param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, - dtype=params_dtype) + dtype=torch.float32) param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, - dtype=params_dtype) - param_dict["w2_weight_scale_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=params_dtype) - param_dict["w2_weight_offset_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=params_dtype) + 
dtype=torch.float32) + + if not self.is_per_channel_weight: + param_dict["w13_weight_scale_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + param_dict["w13_weight_offset_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + + param_dict["w2_weight_scale_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) + param_dict["w2_weight_offset_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) if self.new_quant_version: param_dict["w13_scale_bias"] = torch.empty( @@ -318,8 +321,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_second, - w2_scale=layer.w2_weight_scale_second, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, w1_scale_bias=layer.w13_scale_bias, w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, @@ -343,8 +346,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_second, - w2_scale=layer.w2_weight_scale_second, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, w1_scale_bias=layer.w13_scale_bias, w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, @@ -357,6 +360,14 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: ) def process_scale(self, weight: torch.Tensor, scale, per_group_scale): + scale = scale.transpose(1, 2).contiguous() + if self.is_per_channel_weight: + scale_np = scale.cpu().numpy() + scale_np.dtype = np.uint32 + scale_uint64_tensor = torch.from_numpy(scale_np.astype( + np.int64)).npu() + return scale_uint64_tensor, None + per_group_scale = per_group_scale.transpose(1, 
2).contiguous() group_num, k, n = weight.shape # the weight of the new version is reduced by half by pack n, so it needs to be restored if self.new_quant_version: @@ -399,13 +410,10 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: def pack_to_int32(self, weight: torch.Tensor): if self.new_quant_version: - group_num, k, n = weight.shape - assert n % 4 == 0, "the last dim of weight needs to be divided by 4" - packed_n = n // 4 # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 - packed_weight = torch.from_numpy( - np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32)) - return packed_weight.reshape(group_num, k, packed_n).npu() + assert weight.shape[ + -1] % 4 == 0, "the last dim of weight needs to be divided by 4" + return weight.view(torch.int32).contiguous() else: return torch_npu.npu_quantize(weight.to(torch.float32), torch.tensor([1.]).npu(), None, @@ -417,21 +425,22 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() - layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( - 1, 2).contiguous() - layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( - 1, 2).contiguous() - layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose( - 1, 2).contiguous() - layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose( - 1, 2).contiguous() - - layer.w13_weight_scale_second.data, w13_bias = self.process_scale( + w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr( + layer, "w13_weight_scale_second") else None + w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr( + layer, "w2_weight_scale_second") else None + layer.w13_weight_scale.data, w13_bias = self.process_scale( layer.w13_weight, layer.w13_weight_scale.data, - layer.w13_weight_scale_second.data) - layer.w2_weight_scale_second.data, w2_bias = self.process_scale( + 
w13_weight_scale_second) + layer.w2_weight_scale.data, w2_bias = self.process_scale( layer.w2_weight, layer.w2_weight_scale.data, - layer.w2_weight_scale_second.data) + w2_weight_scale_second) + if hasattr(layer, "w13_weight_scale_second"): + # scale_second is no longer used, release this part of the memory + del layer.w13_weight_scale_second + del layer.w2_weight_scale_second + del layer.w13_weight_offset_second + del layer.w2_weight_offset_second self.update_bias(layer, w13_bias, w2_bias) diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index 5c3fa95..23c4699 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -23,7 +23,6 @@ import torch_npu from vllm.distributed import GroupCoordinator, get_ep_group from vllm.forward_context import get_forward_context -import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group @@ -417,6 +416,7 @@ def torchair_fused_experts_with_all2all( num_experts = w1.shape[0] if expert_map is not None: + assert ep_group is not None, "ep_group must be provided when expert_map is given" global_num_experts = len(expert_map) + global_redundant_expert_num if hasattr(torch_npu, "npu_moe_init_routing_quant"): quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant( @@ -436,8 +436,9 @@ def torchair_fused_experts_with_all2all( gather_sizes = global_expert_tokens.new_empty( global_expert_tokens.shape[0]) - dist.all_to_all_single(gather_sizes, global_expert_tokens) - + dist.all_to_all_single(gather_sizes, + global_expert_tokens, + group=ep_group.device_group) token_counts_combined = torch.stack( [gather_sizes, global_expert_tokens], dim=0) token_counts_combined = 
token_counts_combined.view( @@ -452,10 +453,16 @@ def torchair_fused_experts_with_all2all( gather_size_list = token_counts_combined_cpu[1] scatter_size_list = token_counts_combined_cpu[0] - dist.all_to_all_single(gathered_tokens, quantized_tokens, - scatter_size_list, gather_size_list) - dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list, - gather_size_list) + dist.all_to_all_single(gathered_tokens, + quantized_tokens, + scatter_size_list, + gather_size_list, + group=ep_group.device_group) + dist.all_to_all_single(dynamic_scale, + token_scales, + scatter_size_list, + gather_size_list, + group=ep_group.device_group) hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing( gathered_tokens, @@ -503,9 +510,11 @@ def torchair_fused_experts_with_all2all( index=inverse_indices.to(torch.float32).argsort().to(torch.int32)) hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape) - dist.all_to_all_single(hidden_states, reordered_outputs, - gather_size_list, scatter_size_list) - + dist.all_to_all_single(hidden_states, + reordered_outputs, + gather_size_list, + scatter_size_list, + group=ep_group.device_group) final_hidden_states = torch_npu.npu_moe_finalize_routing( hidden_states, skip1=None, @@ -824,6 +833,7 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp try: device_group = get_mc2_group().device_group @@ -937,6 +947,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: ) fused_moe_state = get_forward_context().fused_moe_state + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All shared_gate_up, shared_dequant_scale = None, None if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: with npu_stream_switch("moe_secondary", 0): @@ -1021,8 +1033,7 @@ 
class TorchairAscendW8A8DynamicFusedMoEMethod: 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() - if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP: - torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ) + torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( layer.w13_weight_scale.data.shape[0], -1) layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to( diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index 81f2968..9f1b40e 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -98,10 +98,12 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder): def __init__( self, + kv_cache_spec, + layer_names, vllm_config: VllmConfig, device: torch.device, ): - super().__init__(vllm_config, device) + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.max_num_blocks_per_req = cdiv( self.model_config.max_model_len, self.vllm_config.cache_config.block_size) @@ -171,8 +173,9 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder): def build( self, + common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, + model: Optional[nn.Module] = None, ): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens @@ -182,11 +185,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder): block_table[:num_reqs]) seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - slot_mapping = common_attn_metadata.slot_mapping_cpu[: - num_actual_tokens].to( - self.device, - non_blocking= - True) + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] attn_mask = common_attn_metadata.attn_mask attn_state = common_attn_metadata.attn_state @@ -374,6 +373,9 @@ class 
AscendAttentionTorchairBackendImpl(AttentionImpl): indices = torch.cat((block_indices, slots_indices), dim=1) torch_npu.npu_scatter_nd_update_(key_cache, indices, key) torch_npu.npu_scatter_nd_update_(value_cache, indices, value) + if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + self.key_cache = key_cache + self.value_cache = value_cache if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: assert attn_metadata is not None @@ -411,11 +413,13 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl): assert attn_metadata is not None assert attn_metadata.attn_mask is not None compress_mask = attn_metadata.attn_mask + batch_size = attn_metadata.query_lens.shape[0] + block_table = attn_metadata.block_tables[:batch_size, :] torch_npu._npu_flash_attention_qlens( query=query, key_cache=self.key_cache, value_cache=self.value_cache, - block_table=attn_metadata.block_tables, + block_table=block_table, mask=compress_mask, seq_len=attn_metadata.query_lens, context_lens=attn_metadata.seq_lens, @@ -431,17 +435,24 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl): block_size = key_cache.shape[1] query = query.view(num_tokens, 1, self.num_heads * self.head_size).contiguous() - output = torch_npu.npu_incre_flash_attention( - query, - key_cache, - value_cache, - num_key_value_heads=self.num_kv_heads, + output, _ = torch_npu.npu_fused_infer_attention_score( + query=query, + key=key_cache, + value=value_cache, + query_rope=None, + key_rope=None, num_heads=self.num_heads, - actual_seq_lengths=seq_lens, - scale_value=self.scale, - block_table=block_table, + num_key_value_heads=self.num_kv_heads, input_layout='BSH', - block_size=block_size) + atten_mask=decode_meta.attn_mask, + sparse_mode=0, + scale=self.scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=seq_lens, + ) else: raise NotImplementedError( "Torchair graph mode with non-MLA attention backend is still 
experimental." diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 30ef293..995173a 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -23,7 +23,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn -from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, npu_stream_switch, npu_wait_tensor) from vllm_ascend.utils import npu_prefetch @@ -176,6 +175,8 @@ class AscendMLATorchairMetadataBuilder: # _attn_mask_builder = None def __init__(self, + kv_cache_spec, + layer_names, vllm_config: VllmConfig, device: torch.device, metadata_cls: Optional[AscendMLATorchairMetadata] = None): @@ -372,6 +373,7 @@ class AscendMLATorchairMetadataBuilder: def build( self, + common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, model: nn.Module, ) -> AscendMLATorchairMetadata: @@ -398,11 +400,7 @@ class AscendMLATorchairMetadataBuilder: device = self.device block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - slot_mapping = common_attn_metadata.slot_mapping_cpu[: - num_actual_tokens].to( - device, - non_blocking= - True) + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] input_positions = common_attn_metadata.positions[: num_actual_tokens].long( ) @@ -492,11 +490,12 @@ class AscendMLATorchairMetadataBuilder: graph_pad_size = common_attn_metadata.graph_pad_size use_torchair_graph = graph_pad_size != -1 if num_decodes > 0: + # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist() max_seq_lens = seq_lens[:num_decodes].max().item() - seq_lens = 
seq_lens[:num_decode_tokens] + seq_lens = seq_lens[:num_decodes] input_positions = input_positions[:num_decode_tokens] - block_table = block_table[:num_decode_tokens, ...] + block_table = block_table[:num_decodes, ...] num_token_pad_size = 0 if use_torchair_graph and common_attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, @@ -535,10 +534,9 @@ class AscendMLATorchairMetadataBuilder: device=input_positions.device) input_positions = torch.cat( [input_positions, position_padding]) - actual_seq_lengths_q = ( - actual_seq_lengths_q + common_attn_metadata. - actual_seq_lengths_q[num_reqs:num_reqs + - num_reqs_pad_size]) + actual_seq_lengths_q = self.pad_actual_seq_len_q( + num_reqs_pad_size, num_reqs, actual_seq_lengths_q, + common_attn_metadata) else: seq_lens_list = seq_lens.tolist() # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) @@ -581,6 +579,48 @@ class AscendMLATorchairMetadataBuilder: enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, ) + def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs, + actual_seq_lengths_q, common_attn_metadata): + """ + Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request + in order to meet the requirement of npu_fused_infer_attention_score. + + In Torchair scenario, the lengths of the queries must be padded to the same length. + And npu_fused_infer_attention_score constraint requires the last element must equal to batch_size(num_tokens). + + For example: + batch_size=36, num_reqs_pad_size=2, num_reqs=16 + By default, each request should have inference 2 token, which means actual_seq_lengths_q should be + [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36]. + + However, mtp torchair + PD scenario, the actual_seq_lengths_q may be + [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token. 
+ In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request. + after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36] + """ + FIA_SEQ_LEN_LIMIT = 16 + need_padding = num_reqs_pad_size != 0 and \ + len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \ + common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT + if need_padding: + padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + start_val = actual_seq_lengths_q[-1] + end_val = padding_seq_len_q[-1] + + num_step = len(padding_seq_len_q) + interpolated = np.round( + np.linspace(start_val, end_val, + num_step + 1)[1:]).astype(int).tolist() + assert interpolated[-1] == end_val + assert len(interpolated) == len(padding_seq_len_q) + actual_seq_lengths_q = actual_seq_lengths_q + interpolated + else: + actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + + return actual_seq_lengths_q + class AscendMLATorchairImpl(MLAAttentionImpl): """ @@ -629,12 +669,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl): self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.running_in_graph = False + self.prefill_mask = None + self.ring_mla_mask_size = 512 - # Adapt torch air graph mode with spec decoding. 
- speculative_config = get_current_vllm_config().speculative_config - if speculative_config is not None: - self.spec_token_num = speculative_config.num_speculative_tokens - assert self.spec_token_num > 0 + self.speculative_config = get_current_vllm_config().speculative_config def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False): # Convert from (B, N, L) to (N, B, L) @@ -775,16 +813,13 @@ class AscendMLATorchairImpl(MLAAttentionImpl): k_nope, v = kv_nope\ .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) - mask = torch.triu( - torch.ones(512, 512, device=query.device, dtype=query.dtype), - 1) torch_npu.atb.npu_ring_mla( q_nope=q_nope, q_rope=q_pe, k_nope=k_nope, k_rope=k_pe, value=v, - mask=mask, + mask=self.prefill_mask, seqlen=seq_len, head_num=self.num_heads, kv_head_num=self.num_heads, @@ -816,104 +851,54 @@ class AscendMLATorchairImpl(MLAAttentionImpl): self.v_head_dim, dtype=query.dtype, device=query.device) + attn_lse = torch.empty(self.num_heads, + num_tokens, + dtype=torch.float32, + device=query.device) k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache - ascend_config = get_ascend_config() + q_pe = query[..., self.qk_nope_head_dim:] + q_nope = query[..., :self.qk_nope_head_dim] + if self.prefill_mask is None: + if q_nope.dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + prefill_mask = torch.triu( + torch.ones(self.ring_mla_mask_size, + self.ring_mla_mask_size, + device=q_nope.device, + dtype=q_nope.dtype), 1) + self.prefill_mask = torch.where(prefill_mask == 1, mask_value, + 0).to(q_nope.dtype) + torch_npu.atb.npu_ring_mla(q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + 
mask=self.prefill_mask, + seqlen=torch.tensor( + attn_metadata.prefill.query_lens, + dtype=torch.int32), + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse) + attn_output, attn_lse = self._compute_prefill_context( \ + query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) - if attn_metadata.attn_state in [ - AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding, - AscendAttentionState.PrefillCacheHit - ] and not ascend_config.chunked_prefill_for_mla: - attn_output_torch = torch.empty(num_tokens, - self.num_heads * self.v_head_dim, - dtype=query.dtype, - device=query.device) - # current requests is chunked in prefill, disable flash attention with chunked prefill - vanilla_chunked_prefill_mla( - output=attn_output_torch, - query=query, - kv_cache=kv_c_and_k_pe_cache, - block_tables=attn_metadata.prefill.block_table, - query_lens=attn_metadata.prefill.query_lens, - context_lens=attn_metadata.prefill.context_lens, - kv_b_proj=self.kv_b_proj, - max_query_len=attn_metadata.prefill.max_query_len, - max_context_len=attn_metadata.prefill.max_seq_lens, - nope_dim=self.qk_nope_head_dim, - rope_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - scale=self.scale, - alibi_slopes=None, - causal=True) - elif attn_metadata.attn_state in [ - AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding, - AscendAttentionState.PrefillCacheHit - ]: - attn_lse = torch.empty(self.num_heads, - num_tokens, - dtype=torch.float32, - device=query.device) - q_pe = query[..., self.qk_nope_head_dim:] - q_nope = query[..., :self.qk_nope_head_dim] - mask = torch.triu( - torch.ones(512, 512, device=query.device, dtype=query.dtype), - 1) # 512: mask only support 512 - if 
attn_metadata.num_prefills > 1: - mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1, - 1) - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=value, - mask=mask, - seqlen=torch.tensor(attn_metadata.prefill.query_lens, - dtype=torch.int32), - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=None, - prev_lse=None, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="mask_type_triu", - input_layout="type_bsnd", - calc_type="calc_type_first_ring", - output=attn_output, - softmax_lse=attn_lse) - attn_output, attn_lse = self._compute_prefill_context( \ - query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) - - elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - key = torch.cat((k_nope, k_pe), dim=-1) - torch_npu._npu_flash_attention( - query=query, - key=key, - value=value, - mask=attn_metadata.attn_mask, - seq_len=attn_metadata.prefill.context_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_heads, - out=attn_output) - attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim) - else: - raise RuntimeError( - "Unexpected path reached, AscendMLATorchairImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !" 
- ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) - if attn_metadata.attn_state in [ - AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding, - AscendAttentionState.PrefillCacheHit - ] and not ascend_config.chunked_prefill_for_mla: - attn_output = attn_output_torch return attn_output @@ -961,7 +946,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl): kv = self.kv_a_proj_with_mqa(hidden_states)[0] # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA" + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv, self.kv_a_layernorm.weight, @@ -1019,8 +1004,11 @@ class AscendMLATorchairImpl(MLAAttentionImpl): self.qk_rope_head_dim) input_layout = "BNSD" - if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: - assert num_tokens % self.spec_token_num == 0 + if attn_metadata.attn_state in [ + AscendAttentionState.SpecDecoding, + AscendAttentionState.ChunkedPrefill + ] and self.speculative_config is not None: + # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill input_layout = "TND" # [bs * q_seq_len, num_heads_per_rank, dim] q_nope = q_nope.view(num_tokens, self.num_heads, -1) @@ -1199,9 +1187,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl): else: decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( attn_metadata.decode.input_positions, - decode_q_pe.contiguous(), - decode_k_pe, - max_seq_len=attn_metadata.decode.max_seq_lens) + decode_q_pe.contiguous(), decode_k_pe) if has_prefill: assert attn_metadata.prefill is not None prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ @@ -1226,9 +1212,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl): else: prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( attn_metadata.prefill.input_positions, - prefill_q_pe.contiguous(), - prefill_k_pe, - max_seq_len=attn_metadata.prefill.max_seq_lens) + prefill_q_pe.contiguous(), prefill_k_pe) assert len( kv_cache diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 2b34f9b..daf6b5d 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -17,6 +17,7 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # isort: skip_file +import math import types from typing import Optional @@ -24,7 +25,6 @@ import torch import torch.distributed as dist import torch.nn as nn import torch_npu -import vllm.envs as envs_vllm from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_dp_group @@ -40,25 +40,39 @@ from vllm_ascend.torchair.utils import ( register_torchair_model, torchair_ops_patch, torchair_quant_method_register, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - is_310p) + is_310p, get_ascend_soc_version, + AscendSocVersion) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner class NPUTorchairModelRunner(NPUModelRunner): def __init__(self, vllm_config: VllmConfig, device: torch.device): + self.ascend_config = get_ascend_config() + self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp super().__init__(vllm_config, device) - ascend_config = get_ascend_config() + if self.speculative_config: + self.actual_seq_lengths_q = list( + range(self.decode_token_per_req, self.max_num_tokens + 1, + self.decode_token_per_req)) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + None, None, vllm_config, device) + + register_torchair_model() + torchair_ops_patch() + torchair_quant_method_register() + if self.enable_shared_expert_dp: + return 
self.new_kv_cache_bytes = -1 self.torchair_compiled_model = None # type: ignore self.torchair_compiled_models = {} # type: ignore - self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph - self.use_cached_kv_cache_bytes = ascend_config.torchair_graph_config.use_cached_kv_cache_bytes - self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes - if ascend_config.torchair_graph_config.graph_batch_sizes_init: + self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph + self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes + self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes + if self.ascend_config.torchair_graph_config.graph_batch_sizes_init: self.init_torchair_graph_batch_sizes() - self.check_torchair_graph_batch_sizes() + self.update_torchair_graph_batch_sizes() torch._dynamo.cache_size.config.cache_size_limit += len( self.torchair_graph_batch_sizes) @@ -67,14 +81,14 @@ class NPUTorchairModelRunner(NPUModelRunner): recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES) self._check_batch_sizes_consistency() - register_torchair_model() - torchair_ops_patch() - torchair_quant_method_register() def _sync_metadata_across_dp( self, num_tokens: int, with_prefill: bool, enable_dbo: bool ) -> tuple[int, Optional[torch.Tensor], bool, bool]: """Override from NPUModelRunner to pad num_tokens""" + if self.enable_shared_expert_dp: + # Padding is not required for shared_expert_dp cases in eager mode. 
+ return num_tokens, None, with_prefill, enable_dbo if self.dp_size == 1: if not with_prefill: maybe_padded_num_tokens = self.select_torchair_padded_batch_size( @@ -107,10 +121,15 @@ class NPUTorchairModelRunner(NPUModelRunner): return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo - def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn): + def _build_attention_metadata(self, with_prefill, num_reqs, num_tokens, + max_query_len, force_attention): # NOTE: If torchair graph mode and not with_prefill, # we can't skip_attn, it will cause graph recompile. - if not with_prefill: + if with_prefill or self.enable_shared_expert_dp: + attn_metadata = super()._build_attention_metadata( + with_prefill, num_reqs, num_tokens, max_query_len, + force_attention) + else: common_attn_metadata = TorchairCommonAttentionMetadata( num_reqs=num_reqs, num_actual_tokens=1, @@ -121,17 +140,19 @@ class NPUTorchairModelRunner(NPUModelRunner): ) attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy( common_attn_metadata) - else: - attn_metadata = super()._build_attention_metadata( - with_prefill, num_reqs, skip_attn) return attn_metadata def _generate_dummy_run_hidden_states(self, with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds): - - if not with_prefill: + if with_prefill or self.enable_shared_expert_dp: + if is_310p(): + converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) + hidden_states = super()._generate_dummy_run_hidden_states( + with_prefill, is_torchair_compile, input_ids, positions, + attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) + else: # Only mark static while compiling if is_torchair_compile: torch._dynamo.mark_static(input_ids) @@ -163,15 +184,11 @@ class NPUTorchairModelRunner(NPUModelRunner): inputs_embeds=None, **model_kwargs, ) - else: - if is_310p(): - converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) - 
hidden_states = super()._generate_dummy_run_hidden_states( - with_prefill, is_torchair_compile, input_ids, positions, - attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) return hidden_states def _convert_torch_format(self, kv_cache): + if self.enable_shared_expert_dp: + return super()._convert_torch_format(kv_cache) kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND) return kv_cache @@ -189,6 +206,8 @@ class NPUTorchairModelRunner(NPUModelRunner): def _capture_model(self): """Override from NPUModelRunner to use torchair graph capture.""" + if self.enable_shared_expert_dp: + return super()._capture_model() # TODO(NeverRaR): Calling graph_capture(device=self.device) in # torchair graph capture can cause some issues, so now we just # temporarily split the codepath for the two different graph patterns. @@ -228,6 +247,8 @@ class NPUTorchairModelRunner(NPUModelRunner): self.new_kv_cache_bytes) def _use_aclgraph(self) -> bool: + if self.enable_shared_expert_dp: + return super()._use_aclgraph() return False def _check_batch_sizes_consistency(self) -> None: @@ -253,10 +274,10 @@ class NPUTorchairModelRunner(NPUModelRunner): ) def _update_graph_pad_size(self, with_prefill, graph_pad_size): - if not with_prefill: - self.graph_pad_size = graph_pad_size - else: + if with_prefill or self.enable_shared_expert_dp: super()._update_graph_pad_size(with_prefill, graph_pad_size) + else: + self.graph_pad_size = graph_pad_size def _update_input_ids_and_positions(self, input_ids, positions, num_input_tokens, with_prefill, @@ -266,7 +287,9 @@ class NPUTorchairModelRunner(NPUModelRunner): input_ids, positions, num_input_tokens, with_prefill, padded_num_tokens_across_dp) - if not with_prefill: + if with_prefill or self.enable_shared_expert_dp: + return input_ids, positions + else: input_ids = self.input_ids[:padded_num_tokens_across_dp] positions = self.positions[:padded_num_tokens_across_dp] return input_ids, positions @@ -276,6 +299,13 @@ class 
NPUTorchairModelRunner(NPUModelRunner): input_ids, positions, intermediate_tensors, inputs_embeds): + if attn_metadata is not None and isinstance(attn_metadata, dict): + attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] + + if self.enable_shared_expert_dp: + return super()._generate_process_reqs_hidden_states( + attn_metadata, with_prefill, padded_num_tokens_across_dp, + input_ids, positions, intermediate_tensors, inputs_embeds) model_kwargs = { "kv_caches": self.kv_caches, "attn_metadata": attn_metadata @@ -332,21 +362,22 @@ class NPUTorchairModelRunner(NPUModelRunner): communication_adaptation_310p() config = torchair.CompilerConfig() - if get_ascend_config().torchair_graph_config.mode: - config.mode = get_ascend_config().torchair_graph_config.mode - config.experimental_config.frozen_parameter = True + if self.ascend_config.torchair_graph_config.mode: + config.mode = self.ascend_config.torchair_graph_config.mode + config.experimental_config.frozen_parameter = \ + self.ascend_config.torchair_graph_config.enable_frozen_parameter # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to # disable it on 300I Duo platform now. 
config.experimental_config.tiling_schedule_optimize = not is_310p() config.experimental_config.enable_view_optimize = \ - get_ascend_config().torchair_graph_config.enable_view_optimize + self.ascend_config.torchair_graph_config.enable_view_optimize torch.npu.set_compile_mode(jit_compile=False) if not self.use_cached_npu_graph: npu_backend = torchair.get_npu_backend(compiler_config=config) self.torchair_compiled_model = torch.compile( self.model, - dynamic=True, - fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + dynamic=not self.ascend_config.use_sfa, + fullgraph=True, backend=npu_backend) return self.torchair_compiled_model else: @@ -368,8 +399,8 @@ class NPUTorchairModelRunner(NPUModelRunner): self.torchair_compiled_models[ batch_size] = torchair.inference.cache_compile( self.model.__dict__[forward_proxy_name], - dynamic=True, - fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + dynamic=not self.ascend_config.use_sfa, + fullgraph=True, cache_dir=TORCHAIR_CACHE_DIR, config=config, ge_cache=False) @@ -396,10 +427,16 @@ class NPUTorchairModelRunner(NPUModelRunner): f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." 
) - def check_torchair_graph_batch_sizes(self): + def update_torchair_graph_batch_sizes(self): # return graph_batch_sizes according to the max number of tokens # first pad according to the number of requests - if len(self.torchair_graph_batch_sizes) == 0: + if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'deepseek_mtp': + # pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs + self.torchair_graph_batch_sizes = [self.max_num_reqs] + logger.warning( + "is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs]" + ) + elif len(self.torchair_graph_batch_sizes) == 0: self.torchair_graph_batch_sizes = [1, self.max_num_reqs] else: self.torchair_graph_batch_sizes = sorted( @@ -420,27 +457,47 @@ class NPUTorchairModelRunner(NPUModelRunner): for graph_batch_size in self.torchair_graph_batch_sizes ] - # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size` + # NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size` + # Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same + # on all EP ranks + if get_ascend_soc_version( + ) == AscendSocVersion.A3 and self.parallel_config.enable_expert_parallel: + self._align_graph_size_divisible_by_tp_size() + + def _align_graph_size_divisible_by_tp_size(self): tp_size = self.parallel_config.tensor_parallel_size - if self.parallel_config.enable_expert_parallel: - new_graph_batch_sizes = [] - for graph_batch_size in self.torchair_graph_batch_sizes: - cur_graph_batch_size = (graph_batch_size + tp_size - - 1) // tp_size * tp_size - if cur_graph_batch_size not in new_graph_batch_sizes and \ - cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens: - new_graph_batch_sizes.append(cur_graph_batch_size) - elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \ - and 
self.decode_token_per_req > 1: - logger.warning( - f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens", - f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size." - ) + new_graph_batch_sizes = [] + for graph_batch_size in self.torchair_graph_batch_sizes: + cur_graph_batch_size = (graph_batch_size + tp_size - + 1) // tp_size * tp_size + # MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size, + # Both adapter multi-dp and FIA operator + if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1: + cur_graph_batch_size = (tp_size * graph_batch_size) \ + // math.gcd(tp_size, graph_batch_size) + if cur_graph_batch_size not in new_graph_batch_sizes and \ + cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens: + new_graph_batch_sizes.append(cur_graph_batch_size) + elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \ + and self.decode_token_per_req > 1: + logger.warning( + f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens", + f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size." + ) + new_max_num_reqs = max(new_graph_batch_sizes) + if self.max_num_reqs != new_max_num_reqs: + logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}") + self.max_num_reqs = new_max_num_reqs + self.scheduler_config.max_num_seqs = new_max_num_reqs + + if new_graph_batch_sizes != self.torchair_graph_batch_sizes: + logger.warning( + f"torchair_graph_batch_sizes are updated to {new_graph_batch_sizes}." 
+ ) self.torchair_graph_batch_sizes = new_graph_batch_sizes def _build_drafter_prepare_inputs_torchair_param(self): - return True - - def get_dp_padding(self, num_tokens): - """Override from NPUModelRunner to get dp padding""" - return 0, None + if self.enable_shared_expert_dp: + return super()._build_drafter_prepare_inputs_torchair_param() + else: + return True diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py new file mode 100644 index 0000000..8dc6a68 --- /dev/null +++ b/vllm_ascend/torchair/torchair_sfa.py @@ -0,0 +1,1330 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_npu +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + MLAAttentionImpl) +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.utils import cdiv, round_down + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + split_decodes_and_prefills) +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata +from vllm_ascend.worker.npu_input_batch import InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + +class AscendSFATorchairBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + 
return "ASCEND_SFA_TORCHAIR" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AscendSFATorchairMetadata + + @staticmethod + def get_builder_cls(): + return AscendSFATorchairMetadataBuilder + + #NOTE: is that ok? + @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_impl_cls() -> Type["MLAAttentionImpl"]: + return AscendSFATorchairImpl + + +@dataclass +class AscendSFATorchairPrefillMetadata: + """ Prefill Specific Metadata for Ascend""" + + @dataclass + class TorchairChunkedContextMetadata: + # New for SFA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + + attn_mask: torch.Tensor + query_lens: list[int] # Check!! + seq_lens: list[int] # Check!! + context_lens: torch.Tensor + input_positions: torch.Tensor + query_start_loc: torch.Tensor + block_table: torch.Tensor + max_query_len: int + max_seq_lens: int + sin: torch.Tensor + cos: torch.Tensor + chunked_context: Optional[TorchairChunkedContextMetadata] = None + + +@dataclass +class AscendSFATorchairDecodeMetadata: + # Input positions for rotrary embeddings since for SFA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + actual_seq_lengths_q: torch.Tensor + sin: torch.Tensor + cos: torch.Tensor + attn_mask: Optional[torch.Tensor] = None + + +@dataclass +class AscendSFATorchairMetadata: + """Metadata for SFACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. 
+ # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + slot_mapping: torch.Tensor + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + block_tables: torch.Tensor + + # New for SFA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + query_lens: Optional[list[int]] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + attn_mask: torch.Tensor = None + # chunked prefill by default if no attn_states passed + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + decode: Optional[AscendSFATorchairDecodeMetadata] = None + prefill: Optional[AscendSFATorchairPrefillMetadata] = None + enable_dbo_across_dp: bool = False + is_prefill: bool = False + is_decode: bool = False + + def __post_init__(self): + pass + # supported_head_sizes = AscendSFABackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f"received {self.head_dim}.") + + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendSFATorchairMetadata"]: + """Split metadata for multi-stream with AscendSFATorchairMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendSFATorchairMetadata, + ) + + +M = TypeVar("M", bound=AscendSFATorchairMetadata) + + +class AscendSFATorchairMetadataBuilder: + """ + NOTE: Please read the comment at the top of 
the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + kv_cache_spec, + layer_names, + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[AscendSFATorchairMetadata] = None): + self.metadata_cls: Optional[AscendSFATorchairMetadata] = metadata_cls \ + if metadata_cls is not None else AscendSFATorchairMetadata # type: ignore + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 SFA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None + + def reorder_batch(self, input_batch: 
"InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + # For torch air graph mode we treat spec decoding as decode. + if self.torchair_graph_enabled: + if num_tokens - num_spec_tokens == 1: + decodes.append(i) + else: + prefills.append(i) + # For eager mode we treat spec decoding as chunked prefill. + else: + if num_tokens == 1: + decodes.append(i) + else: + prefills.append(i) + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. 
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + return modified_batch + + def _get_graph_runner_block_tables( + self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + max_blocks = self.max_blocks + + graph_block_tables = torch.zeros((num_seqs, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) + + num_blocks = block_tables.size(1) + if num_blocks <= max_blocks: + graph_block_tables[:num_seqs, : + num_blocks] = block_tables[:num_seqs, : + num_blocks] + else: + graph_block_tables[:num_seqs, : + max_blocks] = block_tables[:num_seqs, : + max_blocks] + + return graph_block_tables[:, :max_blocks] + + def build_torchair_graph_dummy( + self, + common_attn_metadata: TorchairCommonAttentionMetadata, + ) -> AscendSFATorchairMetadata: + device = self.device + num_reqs = common_attn_metadata.num_reqs + block_table = torch.zeros((num_reqs, self.max_blocks), + dtype=torch.int32, + device=device) + block_table = self._get_graph_runner_block_tables( + num_reqs, block_table) + num_tokens = num_reqs * common_attn_metadata.decode_token_per_req + seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) + seq_lens_list = [0] * num_reqs + input_positions = torch.zeros(num_tokens, + dtype=torch.int32, + device=device).long() + slot_mapping = torch.full((num_tokens, ), + PAD_SLOT_ID, + dtype=torch.int32, + 
device=device) + query_start_loc = torch.full((num_reqs, ), + -1, + dtype=torch.int32, + device=device) + sin = torch.ones(num_tokens, + 1, + 1, + self.rope_dim, + dtype=self.model_config.dtype, + device=device) + cos = torch.ones(num_tokens, + 1, + 1, + self.rope_dim, + dtype=self.model_config.dtype, + device=device) + + if self.vllm_config.speculative_config is not None and\ + self.vllm_config.speculative_config.method == 'deepseek_mtp': + attn_state = AscendAttentionState.SpecDecoding + num_decode_tokens = 2 + else: + attn_state = AscendAttentionState.DecodeOnly + num_decode_tokens = 1 + # cumsum here. + # actual_seq_lengths_q = torch.Tensor(common_attn_metadata.actual_seq_lengths_q[:num_tokens]).to(torch.int32).npu() + # actual_seq_lengths_q = torch.cumsum(actual_seq_lengths_q, dim=0).to(torch.int32).npu() + actual_seq_lengths_q = torch.arange(1, num_reqs + 1).to( + torch.int32).npu( + ) * common_attn_metadata.decode_token_per_req ############## + decode_metadata = AscendSFATorchairDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=1, + attn_mask=common_attn_metadata.spec_attn_mask, + # actual_seq_lengths_q=torch.Tensor(common_attn_metadata.actual_seq_lengths_q[:num_reqs]).to(torch.int32).npu(), + actual_seq_lengths_q=actual_seq_lengths_q, + # actual_seq_lengths_q=torch.Tensor([1]).to(torch.int32).npu(), + sin=sin, + cos=cos, + ) + return self.metadata_cls( # type: ignore + num_input_tokens=common_attn_metadata.num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=0, + attn_mask=common_attn_metadata.attn_mask, + attn_state=attn_state, + prefill=None, + decode=decode_metadata, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_tables=block_table, + is_prefill=False, + 
is_decode=True) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendSFATorchairMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + if self.torchair_graph_enabled and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + decode_threshold = common_attn_metadata.decode_token_per_req + else: + # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding + decode_threshold = 1 + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. 
+ device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping[: + num_actual_tokens].to( + device, + non_blocking=True) + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + + if self.cos_cache is None: + self.cos_cache = model.model.layers[ + 0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + 0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + + # check CPU operation here + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) + + prefill_metadata = None + chunked_context_metadata = None + is_prefill = False + is_decode = False + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens + max_query_len = query_lens[tokens_start:].max().item() + max_seq_lens = seq_lens[tokens_start:].max().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + max_context_chunk = round_down(max_context_chunk, + self.block_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, 
num_prefills) * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata = \ + AscendSFATorchairPrefillMetadata.TorchairChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + actual_query_lens = torch.tensor( + query_lens[tokens_start:], + dtype=torch.int32).npu() # int64->int32 + query_lens_prefill_sfa = torch.cumsum(actual_query_lens, + dim=0).to(torch.int32).npu() + seq_lens_prefill_sfa = torch.tensor(seq_lens, + dtype=torch.int32).npu() + prefill_metadata = AscendSFATorchairPrefillMetadata( + attn_mask=common_attn_metadata.attn_mask, + query_lens=query_lens_prefill_sfa, + seq_lens=seq_lens_prefill_sfa, + context_lens=seq_lens[tokens_start:], + input_positions=prefill_input_positions, + block_table=block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_seq_lens=max_seq_lens, + query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, + ) + is_prefill = True + + decode_metadata = None + graph_pad_size = common_attn_metadata.graph_pad_size + use_torchair_graph = graph_pad_size != -1 + if num_decodes > 0: + # Check here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
+ # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario + actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to( + torch.int32).npu() + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decodes].to(torch.int32).npu() + # input_positions = input_positions[:num_decode_tokens] + block_table = block_table[:num_decodes, ...] + num_token_pad_size = 0 + if use_torchair_graph and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + num_reqs_pad_size = 0 + if graph_pad_size != 0: + pad_value = 0 + num_token_pad_size = graph_pad_size - num_decode_tokens + num_reqs_pad_size = ( + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) + padded_seq_lens = seq_lens.tolist( + ) + [pad_value] * num_reqs_pad_size + else: + padded_seq_lens = seq_lens.tolist() + + seq_lens = torch.from_numpy( + np.array(padded_seq_lens).astype(np.int32)).npu() + seq_lens_list = padded_seq_lens + slot_padding = torch.full((num_token_pad_size, ), + PAD_SLOT_ID, + dtype=slot_mapping.dtype, + device=slot_mapping.device) + slot_mapping = torch.cat([slot_mapping, slot_padding]) + block_table_padding = torch.zeros( + (num_reqs_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], + dim=0) + block_table = self._get_graph_runner_block_tables( + num_reqs + num_reqs_pad_size, block_table) + position_padding = torch.zeros(num_token_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat( + [input_positions, position_padding]) + + # actual_seq_lengths_q = torch.cumsum(actual_seq_lengths_q, dim=0).npu() + # actual_seq_lengths_q=torch.Tensor([1]).to(torch.int32).npu() + actual_seq_lengths_q = torch.arange(1, num_reqs + 1).to( + torch.int32).npu( + ) * common_attn_metadata.decode_token_per_req + # MTP ignored + # actual_seq_lengths_q = 
self.pad_actual_seq_len_q( + # num_reqs_pad_size, num_reqs, actual_seq_lengths_q, + # common_attn_metadata) + else: + seq_lens_list = seq_lens.tolist() + # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) + batch_size = num_decode_tokens + num_token_pad_size + if actual_seq_lengths_q[-1] != batch_size \ + and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: + actual_seq_lengths_q[-1] = batch_size + + cos = self.cos_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + padded_token_num = input_positions.shape[0] + actual_seq_lengths_q = torch.arange( + 1, + (padded_token_num // common_attn_metadata.decode_token_per_req) + + 1).to(torch.int32).npu( + ) * common_attn_metadata.decode_token_per_req + decode_metadata = AscendSFATorchairDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin, + cos=cos) + is_decode = True + + return self.metadata_cls( # type: ignore + num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, + prefill=prefill_metadata, + decode=decode_metadata, + query_start_loc=query_start_loc, + block_tables=block_table, + seq_lens=seq_lens, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + is_prefill=is_prefill, + is_decode=is_decode) + + def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs, + actual_seq_lengths_q, common_attn_metadata): + """ + Pads actual_seq_lengths_q evenly to not 
exceed 16 tokens per request + in order to meet the requirement of npu_fused_infer_attention_score. + + In Torchair scenario, the lengths of the queries must be padded to the same length. + And npu_fused_infer_attention_score constraint requires the last element must equal to batch_size(num_tokens). + + For example: + batch_size=36, num_reqs_pad_size=2, num_reqs=16 + By default, each request should have inference 2 token, which means actual_seq_lengths_q should be + [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36]. + + However, mtp torchair + PD scenario, the actual_seq_lengths_q may be + [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token. + In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request. + after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36] + """ + FIA_SEQ_LEN_LIMIT = 16 + need_padding = num_reqs_pad_size != 0 and \ + len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \ + common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT + if need_padding: + padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + start_val = actual_seq_lengths_q[-1] + end_val = padding_seq_len_q[-1] + + num_step = len(padding_seq_len_q) + interpolated = np.round( + np.linspace(start_val, end_val, + num_step + 1)[1:]).astype(int).tolist() + assert interpolated[-1] == end_val + assert len(interpolated) == len(padding_seq_len_q) + actual_seq_lengths_q = actual_seq_lengths_q + interpolated + else: + actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + + # return actual_seq_lengths_q + return torch.Tensor(actual_seq_lengths_q).to(torch.int32).npu() + + +class PrefillSFAPreprocessResult(NamedTuple): + q_nope: 
Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + k_nope: Optional[torch.Tensor] = None + k_pe: Optional[torch.Tensor] = None + topk_indices: Optional[torch.Tensor] = None + query_states: Optional[torch.Tensor] = None + key_states: Optional[torch.Tensor] = None + + +class DecodeSFAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + # nope_cache: Optional[torch.Tensor] = None + # rope_cache: Optional[torch.Tensor] = None + topk_indices: Optional[torch.Tensor] = None + query_states: Optional[torch.Tensor] = None + key_states: Optional[torch.Tensor] = None + bsz: Optional[int] = None + + +class AscendSFATorchairImpl(MLAAttentionImpl): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + # MLA Args + self.q_lora_rank = kwargs['q_lora_rank'] + self.kv_lora_rank = kwargs['kv_lora_rank'] + self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] + self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] + self.qk_head_dim = kwargs['qk_head_dim'] + self.v_head_dim = kwargs['v_head_dim'] + self.rotary_emb = kwargs['rotary_emb'] + self.q_proj = kwargs['q_proj'] + self.kv_b_proj = kwargs['kv_b_proj'] + self.o_proj = kwargs['o_proj'] + self.indexer = kwargs['indexer'] + self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) + self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) + self.q_a_proj = kwargs.get('q_a_proj', None) + self.q_a_layernorm = 
kwargs.get('q_a_layernorm', None) + self.decoder_layer = kwargs.get('decoder_layer', None) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = self.num_heads // self.tp_size + if self.q_a_proj is not None: + self.q_b_proj = self.q_proj + else: + self.q_b_proj = None + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.enable_prefetch = ascend_config.enable_prefetch + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + if ascend_config.torchair_graph_config.enabled: + self.graph_batch_size = ascend_config.torchair_graph_config.graph_batch_sizes[ + 0] + self.actual_seq_length = torch.arange(1, self.graph_batch_size + + 1).to(torch.int32).npu() + vllm_config = get_current_vllm_config() + self.ring_mla_mask_size = 512 + self.prefill_mask = None + + # indexer param + self.dim = self.indexer.dim + self.n_heads: int = self.indexer.n_heads # 64 + self.head_dim: int = self.indexer.head_dim # 128 + self.index_topk: int = self.indexer.index_topk # 2048 + self.wq_b = self.indexer.wq_b + self.wk = self.indexer.wk + self.weights_proj = self.indexer.weights_proj + self.k_norm = self.indexer.k_norm + self.softmax_scale = self.indexer.softmax_scale + + # Adapt torch air graph mode with spec decoding. 
+ speculative_config = vllm_config.speculative_config + if speculative_config is not None: + self.spec_token_num = speculative_config.num_speculative_tokens + assert self.spec_token_num > 0 + + self.cp_size = 1 + + if self.q_a_proj is not None: + self.prefix = self.q_a_proj.prefix + else: + self.prefix = 0 + self.debug_layer_idx = int(self.prefix.split(".")[2]) + self.layers = vllm_config.model_config.hf_config.num_hidden_layers + self.first_k_dense_replace = vllm_config.model_config.hf_config.first_k_dense_replace + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + 
self.v_head_dim, + ) + + self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous() + # Waiting for BMM NZ support + # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) + # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + if envs_ascend.VLLM_ASCEND_ENABLE_MLAPO: + self._process_weights_for_fused_mlapo(act_dtype) + + def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): + kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data.clone() + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data.clone()), + dim=-1) + wd_qkv = wd_qkv.t().contiguous() + wd_qkv = transdata(wd_qkv, + block_size=(16, 32)).unsqueeze(0).contiguous() + self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) + + kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone() + kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape( + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, + self.qk_rope_head_dim) + kv_a_proj_deq_scl = kv_a_proj_deq_scl.view( + self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.deq_scale_qkv = torch.cat( + (kv_a_proj_deq_scl, self.q_a_proj.deq_scale.clone()), + dim=-1).contiguous() + + kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias.clone() + kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape( + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, + self.qk_rope_head_dim) + kv_a_proj_qt_bias = kv_a_proj_qt_bias.view( + self.kv_lora_rank + 
self.qk_rope_head_dim).contiguous() + self.quant_bias_qkv = torch.cat( + (kv_a_proj_qt_bias, self.q_a_proj.quant_bias.clone()), + dim=-1).contiguous() + + wu_q = self.q_proj.weight.data.clone() + wu_q = wu_q.t().reshape(self.num_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + -1) + wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim) + wu_q = wu_q.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), + -1) + wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() + self.wu_q = torch_npu.npu_format_cast(wu_q, 29) + + qb_deq_scl = self.q_proj.deq_scale.data.clone() + qb_deq_scl = qb_deq_scl.reshape( + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim) + self.qb_deq_scl = qb_deq_scl.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + + qb_qt_bias = self.q_proj.quant_bias.data.clone() + qb_qt_bias = qb_qt_bias.reshape( + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim) + self.qb_qt_bias = qb_qt_bias.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + + self.gamma0 = self.decoder_layer.input_layernorm.weight.data + self.beta0 = self.decoder_layer.input_layernorm.bias.data + self.gamma1 = self.q_a_layernorm.weight.data + self.beta1 = self.q_a_layernorm.bias.data + self.gamma2 = self.kv_a_layernorm.weight.data + self.quant_scale0 = self.q_a_proj.input_scale.data + self.quant_offset0 = self.q_a_proj.input_offset.data + self.quant_scale1 = self.q_proj.input_scale.data + self.quant_offset1 = self.q_proj.input_offset.data + + def _sfa_decode_preprocess(self, hidden_states, kv_cache, attn_metadata, + need_gather_q_kv): + bsz = hidden_states.shape[0] + cos_shape = attn_metadata.decode.cos.shape + cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1]) + sin = attn_metadata.decode.sin.view(cos_shape[0], 
cos_shape[-1]) + ctkv_scale = torch.tensor([1], + dtype=hidden_states.dtype, + device=hidden_states.device) + q_nope_scale = torch.tensor([1], + dtype=hidden_states.dtype, + device=hidden_states.device) + + decode_q_nope, _, decode_q_pe, _ = torch_npu.npu_mla_process( + hidden_states, + self.gamma0, + self.beta0, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + cos, + sin, + self.kv_b_proj_w_k, + kv_cache[0], + kv_cache[1], + attn_metadata.slot_mapping.flatten(), + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=ctkv_scale, + q_nope_scale=q_nope_scale, + cache_mode_opt="krope_ctkv", + quant_mode_opt="per_tensor_quant_asymm", + ) + decode_k_nope = kv_cache[0] + decode_k_pe = kv_cache[1] + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, + self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + + hidden_states = self.decoder_layer.input_layernorm(hidden_states) + decode_kq = self.q_a_proj(hidden_states) # q down + decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm + + topk_indices = self.indexer_select(hidden_states, + decode_q_c, + attn_metadata=attn_metadata, + kv_cache=kv_cache, + is_prefill=False) + query_states = (decode_q_nope, decode_q_pe) + key_states = (decode_k_nope, decode_k_pe) + decode_preprocess_res = DecodeSFAPreprocessResult( + q_nope=decode_q_nope, + q_pe=decode_q_pe, + topk_indices=topk_indices, + query_states=query_states, + key_states=key_states, + bsz=bsz, + ) + return decode_preprocess_res + + def forward( + self, + hidden_states: torch.Tensor, # query in unified attn + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool = False, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor 
must be provided." + if attn_metadata is None: + # Profiling run. + return output + + if attn_metadata.prefill is not None: + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + + bsz = 1 + + hidden_states_prefill = hidden_states + prefill_slot_mapping = attn_metadata.slot_mapping + prefill_kq = self.q_a_proj(hidden_states_prefill) # q down + prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm + prefill_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_prefill) # c_kv + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + prefill_kv_no_split = get_tp_group().all_gather( + prefill_kv_no_split, + 0)[attn_metadata.num_decode_tokens:attn_metadata. + num_actual_tokens] + # prefill_q_c = q_c[ + # num_decode_tokens:num_actual_tokens] + + # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] + + # prefill_kv_no_split = kv_no_split[ + # num_decode_tokens:num_actual_tokens] + # prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens] + prefill_qr = prefill_q_c + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + prefill_qr = get_tp_group().all_gather( + prefill_qr, + 0)[attn_metadata.num_decode_tokens:attn_metadata. 
+ num_actual_tokens] + + prefill_q = self.q_b_proj(prefill_qr) + prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim) + prefill_q_nope, prefill_q_pe = torch.split( + prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + prefill_q_nope = prefill_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + prefill_q_nope = (torch.matmul(prefill_q_nope, + self.kv_b_proj_w_k).transpose( + 1, + 0).view(-1, self.num_heads, + self.kv_lora_rank)) + prefill_q_pe = prefill_q_pe.unsqueeze(2) + + # stream2 kv + + nope_cache = kv_cache[0] + rope_cache = kv_cache[1] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + cos_q, sin_q = cos, sin + + prefill_q_pe = torch_npu.npu_interleave_rope( + prefill_q_pe, cos_q, sin_q) # BNSD + prefill_q_pe = prefill_q_pe.squeeze(2) #BSH + # q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:???? + + prefill_latent_cache = prefill_kv_no_split # (B,S,N,D) + prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + prefill_latent_cache.view( + -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim), + self.kv_a_layernorm.weight, + cos.view(-1, 1, 1, self.qk_rope_head_dim), + sin.view(-1, 1, 1, self.qk_rope_head_dim), + prefill_slot_mapping.to(torch.int64), + rope_cache, + nope_cache, + k_rope_scale=None, + c_kv_scale=None, + k_rope_offset=None, + c_kv_offset=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode="PA") + + topk_indices = self.indexer_select(x=hidden_states_prefill, + qr=prefill_qr, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + is_prefill=True) + query_states = (prefill_q_nope, prefill_q_pe) + key_states = (prefill_k_nope, prefill_k_pe) + q_nope, q_pe = query_states + k_nope, k_rope = key_states + prefill_metadata = attn_metadata.prefill + + slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + 
block_table=prefill_metadata.block_table, + actual_seq_lengths_query=prefill_metadata.query_lens, + actual_seq_lengths_kv=prefill_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + slc_fa_fusion = slc_fa_fusion.transpose(0, 1) + + # input shape [N//attn_tp_size, T(bs*q_len), D] + # output shape [T(bs*q_len), N//attn_tp_size, D] + attn_output = torch.matmul( + slc_fa_fusion, self.kv_b_proj_w_v).transpose(1, 0).reshape( + -1, self.num_heads * self.v_head_dim) + # o_proj_input[num_decode_tokens:] = attn_output + output[...] = self.o_proj(attn_output, is_force_scatter=True) + return output + + elif attn_metadata.decode is not None: + if envs_ascend.VLLM_ASCEND_ENABLE_MLAPO: + prep_res = self._sfa_decode_preprocess(hidden_states, kv_cache, + attn_metadata, + need_gather_q_kv) + q_nope, q_pe = prep_res.query_states + k_nope, k_rope = prep_res.key_states + topk_indices = prep_res.topk_indices + else: + q_len = 1 + hidden_states_decode = hidden_states + decode_kq = self.q_a_proj(hidden_states_decode) # q down + decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm + decode_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_decode) # c_kv + # self.actual_seq_length = torch.arange(1,self.graph_batch_size+1).to(torch.int32).npu() + + # decode_q_c = q_c[:num_decode_tokens] + decode_slot_mapping = attn_metadata.slot_mapping + + decode_q = self.q_b_proj(decode_q_c) + bsz, _ = decode_q.shape + decode_q = decode_q.view(bsz, self.num_heads, 1, + self.qk_head_dim) # [16, 16, 1, 192] + decode_q_nope, decode_q_pe = torch.split( + decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) # [..., 128/64] + decode_q_nope = decode_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + decode_q_nope = (torch.matmul( + decode_q_nope, self.kv_b_proj_w_k).transpose(1, 0).view( + bsz, q_len, self.num_heads, self.kv_lora_rank)) + + # stream2 kv + key_cache = kv_cache[0] + 
value_cache = kv_cache[1] + cos = attn_metadata.decode.cos # [16, 1, 1, 64] + sin = attn_metadata.decode.sin + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze( + 1) + decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + decode_kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + decode_slot_mapping.to(torch.int64), + value_cache, + key_cache, + c_kv_scale=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode='PA') # adapter NZ + # nz_block_size = 16 + # KVCACHE_NZ_DIM = 16 + # decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size) + # decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM) + decode_q_pe = torch_npu.npu_interleave_rope( + decode_q_pe, cos, sin) # BNSD + + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, + self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + + topk_indices = self.indexer_select(hidden_states_decode, + decode_q_c, + attn_metadata=attn_metadata, + kv_cache=kv_cache, + is_prefill=False) + + query_states = (decode_q_nope, decode_q_pe) + key_states = (decode_k_nope, decode_k_rope) + q_nope, q_pe = query_states + k_nope, k_rope = key_states + + decode_metadata = attn_metadata.decode + slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=attn_metadata.decode.block_table, + actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=decode_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + slc_fa_fusion = slc_fa_fusion.squeeze(1) + slc_fa_fusion = 
slc_fa_fusion.transpose(0, 1) + + # input shape [N//attn_tp_size, T(bs*q_len), D] + # output shape [T(bs*q_len), N//attn_tp_size, D] + attn_output = torch.matmul( + slc_fa_fusion, self.kv_b_proj_w_v).transpose(1, 0).reshape( + -1, self.num_heads * self.v_head_dim) + output[...] = self.o_proj(attn_output) + return output + + def mla_epilog(self, + attn_output: torch.Tensor = None, + absorb: bool = False): + # TODO: + attn_output = self.o_proj(attn_output) + return attn_output + + def indexer_select( + self, + x: torch.Tensor, + qr: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + is_prefill: bool = True, + ): + if attn_metadata.prefill is not None: + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + elif attn_metadata.decode is not None: + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + # q process in new stream + q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128] + q_pe, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64,64+64] + + q_pe = q_pe.unsqueeze(2) + q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q) + q_pe = q_pe.squeeze(2) + q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + + k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] + if self.enable_shared_expert_dp and is_prefill and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + k_proj = get_tp_group().all_gather( + k_proj, 0)[attn_metadata.num_decode_tokens:attn_metadata. 
+ num_actual_tokens] + k = self.k_norm(k_proj).unsqueeze(1) + k_pe, k_nope = torch.split( + k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64+64] + + k_pe = k_pe.unsqueeze(2) + k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = k_pe.squeeze(2) + + k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + + if kv_cache is not None: + torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), + attn_metadata.slot_mapping.view( + -1, 1), + k.view(-1, + k.shape[-1])) # b, s, n, d + + weights = self.weights_proj(x) + if self.enable_shared_expert_dp and is_prefill and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + weights = get_tp_group().all_gather( + weights, 0)[attn_metadata.num_decode_tokens:attn_metadata. + num_actual_tokens] + actual_seq_lengths_query = None + actual_seq_lengths_key = None + block_table = None + if attn_metadata.prefill is not None: + actual_seq_lengths_query = attn_metadata.prefill.query_lens + actual_seq_lengths_key = attn_metadata.prefill.seq_lens + + block_table = attn_metadata.prefill.block_table + elif attn_metadata.decode is not None: + actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q + actual_seq_lengths_key = attn_metadata.decode.seq_lens.to( + torch.int32) + + block_table = attn_metadata.decode.block_table + + topk_indices = torch.ops.custom.npu_lightning_indexer( + query=q, + key=kv_cache[2], + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3) + return topk_indices + + +def round_up(val: int, align: int) -> int: + if align == 0: + return 0 + return -(val // -align) * align + + +def trans_rope_weight(weight, rope_dim): + weight_1 = weight[..., -rope_dim::2, :].contiguous() + weight_2 = weight[..., -rope_dim + 1::2, :].contiguous() + weight[..., 
-rope_dim:, :] = torch.cat([weight_1, weight_2], dim=-2) + + return weight.contiguous() + + +def transdata(nd_mat, block_size: tuple = (16, 16)): + r = round_up(nd_mat.shape[0], block_size[0]) + c = round_up(nd_mat.shape[1], block_size[1]) + r_pad = r - nd_mat.shape[0] + c_pad = c - nd_mat.shape[1] + nd_mat = F.pad(nd_mat, ((0, r_pad, 0, c_pad))) + nz_mat = torch.permute( + torch.reshape( + nd_mat, + (r // block_size[0], block_size[0], c // block_size[1], + block_size[1]), + ), + [2, 0, 1, 3], + ) + nz_mat = torch.reshape( + nz_mat, + (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])) + return nz_mat diff --git a/vllm_ascend/torchair/torchair_worker.py b/vllm_ascend/torchair/torchair_worker.py index 85f2fb4..dbee800 100644 --- a/vllm_ascend/torchair/torchair_worker.py +++ b/vllm_ascend/torchair/torchair_worker.py @@ -32,28 +32,28 @@ class NPUTorchairWorker(NPUWorker): """Override determine_available_memory to use cached torchair kv_cache_bytes.""" available_kv_cache_memory = super().determine_available_memory() - - if get_ascend_config( - ).torchair_graph_config.use_cached_kv_cache_bytes and check_kv_cache_bytes_cache_exist( - ): - old_kv_cache_bytes = read_kv_cache_bytes_from_file( - torch.distributed.get_rank()) - if 0 < old_kv_cache_bytes <= available_kv_cache_memory: - logger.info( - f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}" - ) - self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes - return old_kv_cache_bytes - else: - logger.info( - "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache" - ) - delete_torchair_cache_file() - bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE - available_kv_cache_memory -= bytes_floating_tolerance - logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}") - self.model_runner.new_kv_cache_bytes = available_kv_cache_memory - + ascend_config = get_ascend_config() + if ascend_config.enable_shared_expert_dp: + 
return available_kv_cache_memory + if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes: + if check_kv_cache_bytes_cache_exist(): + old_kv_cache_bytes = read_kv_cache_bytes_from_file( + torch.distributed.get_rank()) + if 0 < old_kv_cache_bytes <= available_kv_cache_memory: + logger.info( + f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}" + ) + self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes + return old_kv_cache_bytes + else: + logger.info( + "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache" + ) + delete_torchair_cache_file() + bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE + available_kv_cache_memory -= bytes_floating_tolerance + logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}") + self.model_runner.new_kv_cache_bytes = available_kv_cache_memory return available_kv_cache_memory def init_device(self): diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 13d5879..668a7e7 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -165,6 +165,11 @@ def register_torchair_model(): "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" ) + ModelRegistry.register_model( + "DeepseekV32ForCausalLM", + "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" + ) + ModelRegistry.register_model( "Qwen2ForCausalLM", "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM") @@ -180,20 +185,31 @@ def register_torchair_model(): def torchair_quant_method_register(): - from vllm_ascend.quantization.quantizer import \ - SUPPORT_ASCEND_QUANTIZER_TYPE - from vllm_ascend.torchair.quantization.torchair_quantizer import ( - TorchairW4A8DYNAMICQuantizer, TorchairW8A8DYNAMICQuantizer) + from vllm_ascend.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP + from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import ( + 
TorchairAscendW4A8DynamicFusedMoEMethod, + TorchairAscendW4A8DynamicLinearMethod) + from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( + TorchairAscendW8A8DynamicFusedMoEMethod, + TorchairAscendW8A8DynamicLinearMethod) - SUPPORT_ASCEND_QUANTIZER_TYPE[ - "W8A8_DYNAMIC"] = TorchairW8A8DYNAMICQuantizer - SUPPORT_ASCEND_QUANTIZER_TYPE[ - "W4A8_DYNAMIC"] = TorchairW4A8DYNAMICQuantizer + ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][ + "linear"] = TorchairAscendW8A8DynamicLinearMethod + ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][ + "moe"] = TorchairAscendW8A8DynamicFusedMoEMethod + ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][ + "linear"] = TorchairAscendW4A8DynamicLinearMethod + ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][ + "moe"] = TorchairAscendW4A8DynamicFusedMoEMethod def torchair_ops_patch(): + from vllm_ascend.ops.activation import AscendSiluAndMul + from vllm_ascend.ops.layernorm import AscendRMSNorm from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) + from vllm_ascend.torchair.ops import (torchair_activation, + torchair_layernorm) from vllm_ascend.torchair.ops.torchair_rotary_embedding import ( deepseek_rope_init_func, native_rope_deepseek_forward, qwen_rope_init_func, rope_forward) @@ -203,3 +219,6 @@ def torchair_ops_patch(): AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func # type: ignore[method-assign] AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward # type: ignore[method-assign] + + AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign] + AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign] diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index adab490..805fd57 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -21,13 +21,13 @@ import atexit import functools import math import os -from 
contextlib import contextmanager +from contextlib import contextmanager, nullcontext from enum import Enum from threading import Lock -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import torch -import torch_npu # noqa: F401 # noqa: F401 +import torch_npu # noqa: F401 from packaging.version import InvalidVersion, Version from torch_npu.npu.streams import Event from vllm.logger import logger @@ -40,15 +40,9 @@ if TYPE_CHECKING: else: VllmConfig = None -# NOTE: Currently, we can only capture 1920 graphs at most, -# due to the limitation of ACL graph. This number is bounded by -# the number of streams, which is 2048, we save 128 streams -# as a buffer. -# Maximum number of graphs that can be captured by ACL Graph -MAX_CAPTURE_SIZE = 1920 - ASCEND_QUANTIZATION_METHOD = "ascend" SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"] +REGISTERED_ASCEND_OPS = {} ACL_FORMAT_FRACTAL_ND = 2 ACL_FORMAT_FRACTAL_NZ = 29 @@ -186,7 +180,7 @@ def try_register_lib(lib_name: str, lib_info: str = ""): def enable_custom_op(): """ - Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component. + Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component. Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device(). """ global _CUSTOM_OP_ENABLED @@ -291,6 +285,14 @@ def get_max_hidden_layers(hf_config) -> int: def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: """Update ACL graph capture sizes based on hardware limitations""" + # NOTE: Currently, we can only capture 1800 graphs at most, + # due to the limitation of ACL graph. This number is bounded by + # the number of streams, which is 2048, we save 248 streams + # as a buffer. 
+ # Maximum number of graphs that can be captured by ACL Graph + # TODO: Find out whether we need to solve allreduce function + MAX_CAPTURE_SIZE = 1800 + # Store original configuration and temporarily clear it compilation_config = vllm_config.compilation_config original_sizes, compilation_config.cudagraph_capture_sizes = \ @@ -304,6 +306,12 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: num_hidden_layers = get_max_hidden_layers(hf_config) parallel_config = vllm_config.parallel_config + # Calculate maximum supported batch sizes considering model architecture + resources_per_graph = num_hidden_layers + 1 + if vllm_config.speculative_config is not None: + draft_model_hf_config = vllm_config.speculative_config.draft_model_config.hf_config + resources_per_graph += draft_model_hf_config.num_hidden_layers + 1 + # TODO: Find out whether we need to take into account the pp_size num_comm_groups = sum(size > 1 for size in [ parallel_config.data_parallel_size, @@ -313,13 +321,22 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV': # TODO: Find out whether we need to take into account the pp_size parallel_factor = 1 + num_comm_groups + int( - parallel_config.enable_expert_parallel) + parallel_config.enable_expert_parallel) + int( + vllm_config.additional_config.get( + "multistream_overlap_shared_expert", False)) + if is_moe_model(vllm_config): + parallel_factor += (parallel_config.data_parallel_size > 1) + else: + # When AIV mode is enabled, the allreduce operator of the dense + # layer model will occupy additional streams, which are buffered here. 
+ MAX_CAPTURE_SIZE = MAX_CAPTURE_SIZE - parallel_factor * resources_per_graph + # Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device # Assume the following case: # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, # According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19 - max_num_batch_sizes = math.floor( - MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor) + max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / + resources_per_graph / parallel_factor) logger.info( "Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes) @@ -335,8 +352,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, # According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12 max_num_batch_sizes = math.floor( - (MAX_CAPTURE_SIZE - num_comm_groups * 40) / - (num_hidden_layers + 1) / (1 + num_comm_groups * 2)) + (MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph / + (1 + num_comm_groups * 2)) logger.info( "Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes) @@ -473,10 +490,10 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool): return False -def register_ascend_customop(): +def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): """Register Ascend CustomOP - NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`, + NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`, and ensure this will execute after model config is initilazed. 
""" global _ASCEND_CUSTOMOP_IS_REIGISTERED @@ -484,43 +501,49 @@ def register_ascend_customop(): return from vllm.model_executor.custom_op import CustomOp + from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul - from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear, - AscendMlpMergedColumnParallelLinear, - AscendMlpRowParallelLinear) + from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE, + AscendSharedFusedMoE) + from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm, + AscendQuantRMSNorm, AscendRMSNorm) + from vllm_ascend.ops.linear import (AscendColumnParallelLinear, + AscendMergedColumnParallelLinear, + AscendQKVParallelLinear, + AscendRowParallelLinear) from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) from vllm_ascend.ops.vocab_parallel_embedding import ( AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding) - CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU") - CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul, - name="SiluAndMul") - CustomOp.register_oot(_decorated_op_cls=AscendRotaryEmbedding, - name="RotaryEmbedding") - CustomOp.register_oot( - _decorated_op_cls=AscendDeepseekScalingRotaryEmbedding, - name="DeepseekScalingRotaryEmbedding") - CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding, - name="VocabParallelEmbedding") - CustomOp.register_oot(_decorated_op_cls=AscendParallelLMHead, - name="ParallelLMHead") - CustomOp.register_oot(_decorated_op_cls=AscendLogitsProcessor, - name="LogitsProcessor") - if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE: - CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear, - name="ColumnParallelLinear") - CustomOp.register_oot(_decorated_op_cls=AscendMlpRowParallelLinear, - name="RowParallelLinear") - CustomOp.register_oot( - 
_decorated_op_cls=AscendMlpMergedColumnParallelLinear, - name="MergedColumnParallelLinear") - from vllm_ascend.ops.layernorm import AscendRMSNorm - CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm") + global REGISTERED_ASCEND_OPS + REGISTERED_ASCEND_OPS = { + "QuickGELU": AscendQuickGELU, + "SiluAndMul": AscendSiluAndMul, + "RotaryEmbedding": AscendRotaryEmbedding, + "ColumnParallelLinear": AscendColumnParallelLinear, + "RowParallelLinear": AscendRowParallelLinear, + "MergedColumnParallelLinear": AscendMergedColumnParallelLinear, + "QKVParallelLinear": AscendQKVParallelLinear, + "DeepseekScalingRotaryEmbedding": AscendDeepseekScalingRotaryEmbedding, + "VocabParallelEmbedding": AscendVocabParallelEmbedding, + "ParallelLMHead": AscendParallelLMHead, + "LogitsProcessor": AscendLogitsProcessor, + "RMSNorm": AscendRMSNorm, + "GemmaRMSNorm": AscendGemmaRMSNorm, + "FusedMoE": AscendFusedMoE, + "SharedFusedMoE": AscendSharedFusedMoE, + "MultiHeadLatentAttention": AscendMultiHeadLatentAttention, + } - from vllm_ascend.ops.common_fused_moe import AscendFusedMoE - CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE") + if vllm_config is not None and \ + vllm_config.quant_config is not None and \ + any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): + REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm + + for name, op_cls in REGISTERED_ASCEND_OPS.items(): + CustomOp.register_oot(_decorated_op_cls=op_cls, name=name) # NOTE: Keep this at last to ensure all custom actions are registered _ASCEND_CUSTOMOP_IS_REIGISTERED = True @@ -556,3 +579,74 @@ def get_ascend_soc_version(): def lmhead_tp_enable() -> bool: return get_ascend_config().lmhead_tensor_parallel_size is not None + + +def oproj_tp_enable() -> bool: + return get_ascend_config().oproj_tensor_parallel_size is not None + + +def mlp_tp_enable() -> bool: + return envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE + + +def matmul_allreduce_enable() -> bool: 
+ return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE + + +def dense_optim_enable() -> bool: + return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE + + +def enable_sp(vllm_config=None) -> bool: + if vllm_config is None: + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + return ( + vllm_config.compilation_config.pass_config.enable_sequence_parallelism + or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM) + + +def is_moe_model(vllm_config: VllmConfig): + config = vllm_config.model_config.hf_config + return any('experts' in key.lower() for key in config.to_dict()) + + +def weak_ref_tensor(tensor: Any) -> Any: + """ + Create a weak reference to a tensor. + The new tensor will share the same data as the original tensor, + but will not keep the original tensor alive. + """ + if isinstance(tensor, torch.Tensor): + return torch.ops._C_ascend.weak_ref_tensor(tensor) + else: + return tensor + + +def weak_ref_tensors( + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] +) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. + """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + raise ValueError("Invalid type for tensors") + + +def npu_stream_switch(target_stream: torch.npu.Stream, + *, + enabled: bool = True): + """ + Switch to the target stream if enabled is True. + Otherwise, do nothing. 
+ """ + if not enabled: + return nullcontext() + assert target_stream is not None + return torch.npu.stream(target_stream) diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py new file mode 100644 index 0000000..307eb83 --- /dev/null +++ b/vllm_ascend/worker/block_table.py @@ -0,0 +1,312 @@ +from typing import Optional, Union + +import numpy as np +import torch +from vllm.distributed import get_dcp_group +from vllm.utils import cdiv + + +class BlockTable: + + def __init__(self, + block_size: int, + max_num_reqs: int, + max_num_blocks_per_req: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + kernel_sizes: Union[list[int], None] = None): + self.max_num_reqs = max_num_reqs + self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_batched_tokens = max_num_batched_tokens + self.pin_memory = pin_memory + self.device = device + self.physical_block_size = block_size + # If kernel_sizes is None or [0], use physical block size (no splitting) + if kernel_sizes is None or kernel_sizes == [0]: + self.block_size = block_size + self.logical_block_size = block_size + self.blocks_per_phys_block = 1 + self.use_hybrid_blocks = False + else: + # Find the first kernel size that divides physical_block_size evenly + selected_kernel_size = None + for kernel_size in kernel_sizes: + if kernel_size > 0 \ + and self.physical_block_size % kernel_size == 0: + selected_kernel_size = kernel_size + break + + if selected_kernel_size is None: + raise ValueError( + f"None of the kernel sizes {kernel_sizes} can divide " + f"physical block size {self.physical_block_size} evenly") + + self.block_size = selected_kernel_size + self.logical_block_size = selected_kernel_size + self.blocks_per_phys_block = (self.physical_block_size // + self.logical_block_size) + if self.blocks_per_phys_block > 1: + self.use_hybrid_blocks = True + else: + self.use_hybrid_blocks = False + + if self.use_hybrid_blocks: + logical_table_size = 
(max_num_blocks_per_req * + self.blocks_per_phys_block) + else: + logical_table_size = max_num_blocks_per_req + + self.block_table = torch.zeros( + (max_num_reqs, logical_table_size), + device=self.device, + dtype=torch.int32, + ) + self.block_table_cpu = torch.zeros( + (max_num_reqs, logical_table_size), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_np = self.block_table_cpu.numpy() + self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + + self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.slot_mapping = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device=self.device) + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.kernel_sizes = kernel_sizes + + def append_row( + self, + block_ids, + row_idx: int, + ) -> None: + if not block_ids: + return + block_ids = np.array(block_ids) + if self.use_hybrid_blocks: + block_ids = self._convert_physical_to_logical_blocks(block_ids) + + num_blocks = len(block_ids) + start = self.num_blocks_per_row[row_idx] + + self.block_table_np[row_idx, start:start + num_blocks] = block_ids + self.num_blocks_per_row[row_idx] += num_blocks + + def add_row(self, block_ids: list[int], row_idx: int) -> None: + self.num_blocks_per_row[row_idx] = 0 + self.append_row(block_ids, row_idx) + + def move_row(self, src: int, tgt: int) -> None: + num_blocks = self.num_blocks_per_row[src] + self.block_table_np[tgt, :num_blocks] = self.block_table_np[ + src, :num_blocks] + self.num_blocks_per_row[tgt] = num_blocks + + def swap_row(self, src: int, tgt: int) -> None: + num_blocks_src = self.num_blocks_per_row[src] + num_blocks_tgt = self.num_blocks_per_row[tgt] + 
self.num_blocks_per_row[src] = num_blocks_tgt + self.num_blocks_per_row[tgt] = num_blocks_src + + self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] + + def compute_slot_mapping(self, req_indices: np.ndarray, + positions: np.ndarray) -> None: + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + + if self.dcp_world_size > 1: + # Note(hc): The DCP implement store kvcache with an interleave + # style, the kvcache for the token whose token_idx is i is + # always stored on the GPU whose dcp_rank equals i % cp_world_size: + + # Use a "virtual block" which equals to world_size * block_size + # for block_table_indices calculation. + virtual_block_size = self.block_size * self.dcp_world_size + + # IMPORTANT: In hybrid mode, positions are in logical block space, + # but we need to map them to the correct logical block table indices + logical_block_idx = positions // virtual_block_size + + # Account for the expanded logical table + # (always needed with unified tensor) + # Each physical block is split into multiple logical blocks + # The logical table has been expanded to accommodate this + block_table_indices = (req_indices * self.max_num_blocks_per_req * + self.blocks_per_phys_block + + logical_block_idx) + + block_numbers = self.block_table_np.ravel()[block_table_indices] + # Use virtual_block_size for mask calculation, which marks local + # tokens. 
+ virtual_block_offsets = positions % virtual_block_size + mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + # Calculate local block_offsets + block_offsets = virtual_block_offsets // self.dcp_world_size + # Calculate slot_mapping + slot_mapping = block_numbers * self.block_size + block_offsets + # Write final slots, use -1 for not-local + self.slot_mapping_np[:req_indices.shape[0]] = np.where( + mask, slot_mapping, -1) + else: + assert self.kernel_sizes is not None + if self.block_size == self.kernel_sizes[0]: + # IMPORTANT: In hybrid mode, positions are in logical block space, + # but we need to map them to the correct logical block table indices + logical_block_idx = positions // self.block_size + + # Account for the expanded logical table + # (always needed with unified tensor) + # Each physical block is split into multiple logical blocks + # The logical table has been expanded to accommodate this + block_table_indices = ( + req_indices * self.max_num_blocks_per_req * + self.blocks_per_phys_block + logical_block_idx) + + block_numbers = self.block_table_np.ravel( + )[block_table_indices] + block_offsets = positions % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:req_indices.shape[0]]) + + def commit_block_table(self, num_reqs: int) -> None: + self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], + non_blocking=True) + + def commit_slot_mapping(self, num_tokens: int) -> None: + self.slot_mapping[:num_tokens].copy_( + self.slot_mapping_cpu[:num_tokens], non_blocking=True) + + def clear(self) -> None: + self.block_table.fill_(0) + self.block_table_cpu.fill_(0) + + def _convert_physical_to_logical_blocks( + self, physical_blocks: np.ndarray) -> np.ndarray: + """Convert physical block IDs to logical block IDs.""" + if not self.use_hybrid_blocks: + return physical_blocks + + # Create logical block IDs by splitting each physical block + logical_blocks: list[int] = [] + for 
phys_block in physical_blocks: + # Convert physical block to multiple logical blocks + # Physical block 1 becomes logical blocks + # [1*split_ratio, 1*split_ratio+1, ...] + # But we need to account for the fact that block 0 is special + base_logical = phys_block * self.blocks_per_phys_block + logical_blocks.extend( + range(base_logical, base_logical + self.blocks_per_phys_block)) + + return np.array(logical_blocks, dtype=np.int32) + + def get_device_tensor(self) -> torch.Tensor: + """Returns the device tensor of the block table.""" + return self.block_table + + def get_cpu_tensor(self) -> torch.Tensor: + """Returns the CPU tensor of the block table.""" + return self.block_table_cpu + + def get_numpy_array(self) -> np.ndarray: + """Returns the numpy array of the block table.""" + return self.block_table_np + + +class MultiGroupBlockTable: + """The BlockTables for each KV cache group.""" + + def __init__(self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + block_sizes: list[int], + num_speculative_tokens: int = 0, + kernel_sizes: Optional[list[list[int]]] = None) -> None: + # Note(hc): each dcp rank only store + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size which used for calc max_num_blocks_per_req + # must be multiplied by dcp_world_size. 
+ try: + dcp_world_size = get_dcp_group().world_size + except AssertionError: + # DCP might not be initialized in testing + dcp_world_size = 1 + + if kernel_sizes is None: + kernel_sizes = [[0]] * len(block_sizes) + # Ensure kernel_sizes matches block_sizes length + elif len(kernel_sizes) == 1 and len(block_sizes) > 1: + kernel_sizes = kernel_sizes * len(block_sizes) + elif len(kernel_sizes) != len(block_sizes): + raise ValueError( + f"kernel_sizes length ({len(kernel_sizes)}) must match " + f"block_sizes length ({len(block_sizes)})") + + # Use zip to pair block_sizes with kernel_sizes one-to-one + self.block_tables = [ + BlockTable( + block_size, max_num_reqs, + max(cdiv(max_model_len, block_size * dcp_world_size), + 1 + num_speculative_tokens), max_num_batched_tokens, + pin_memory, device, kernel_size_list) + for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) + ] + + def append_row(self, block_ids: tuple[list[int], ...], + row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.append_row(block_ids[i], row_idx) + + def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.add_row(block_ids[i], row_idx) + + def move_row(self, src: int, tgt: int) -> None: + for block_table in self.block_tables: + block_table.move_row(src, tgt) + + def swap_row(self, src: int, tgt: int) -> None: + for block_table in self.block_tables: + block_table.swap_row(src, tgt) + + def compute_slot_mapping(self, req_indices: np.ndarray, + positions: np.ndarray) -> None: + for block_table in self.block_tables: + block_table.compute_slot_mapping(req_indices, positions) + + def commit_block_table(self, num_reqs: int) -> None: + for block_table in self.block_tables: + block_table.commit_block_table(num_reqs) + + def commit_slot_mapping(self, num_tokens: int) -> None: + for block_table in self.block_tables: + block_table.commit_slot_mapping(num_tokens) + + def 
clear(self) -> None: + for block_table in self.block_tables: + block_table.clear() + + def __getitem__(self, idx: int) -> "BlockTable": + """Returns the BlockTable for the i-th KV cache group.""" + return self.block_tables[idx] diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py deleted file mode 100644 index 479ef1d..0000000 --- a/vllm_ascend/worker/eagle_proposer_v1.py +++ /dev/null @@ -1,398 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import os - -import torch -import torch.nn as nn -from vllm.attention.layer import Attention -from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config) -from vllm.distributed.parallel_state import get_pp_group -from vllm.logger import logger -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models import supports_multimodal -from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.v1.sample.metadata import SamplingMetadata - -from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.attention.attention_mask import AttentionMaskBuilder -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata - -PADDING_SLOT_ID = -1 - - -class EagleProposer: - - def __init__(self, - vllm_config: VllmConfig, - device: torch.device, - runner=None): - self.vllm_config = vllm_config - self.speculative_config = vllm_config.speculative_config - self.draft_model_config = self.speculative_config.draft_model_config - self.method = self.speculative_config.method - self.runner = runner - self.model_config = vllm_config.model_config - self.dtype = vllm_config.model_config.dtype - self.max_model_len = vllm_config.model_config.max_model_len - self.block_size = vllm_config.cache_config.block_size - self.num_speculative_tokens = ( - self.speculative_config.num_speculative_tokens) - self.max_num_tokens = ( - 
vllm_config.scheduler_config.max_num_batched_tokens) - self.device = device - # We need to get the hidden size from the draft model config because - # the draft model's hidden size can be different from the target model's - # hidden size (e.g., Llama 3.3 70B). - self.hidden_size = self.draft_model_config.get_hidden_size() - - self.use_cuda_graph = (self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE and - not self.vllm_config.model_config.enforce_eager) - self.cudagraph_batch_sizes = list( - reversed( - self.vllm_config.compilation_config.cudagraph_capture_sizes)) - - # persistent buffers for cuda graph - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=device) - self.hidden_states = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. 
- self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + - 1, - device=device, - dtype=torch.int32) - mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000) - self.attn_mask_len = min(self.model_config.max_model_len, - int(mask_len)) - self.attn_mask_builder = AttentionMaskBuilder(self.attn_mask_len, - self.dtype) - - def _make_attention_mask( - self, - seq_lens, - position, - ) -> torch.Tensor: - return self.attn_mask_builder.get_splitfuse_attn_mask( - seq_lens, position, self.dtype, self.device) - - def propose( - self, - # [num_tokens] - target_token_ids: torch.Tensor, - # [num_tokens] - target_positions: torch.Tensor, - # [num_tokens, hidden_size] - target_hidden_states: torch.Tensor, - # [num_tokens] - target_slot_mapping: torch.Tensor, - # [batch_size] - next_token_ids: torch.Tensor, - # [batch_size + 1] starting with 0 - cu_num_tokens: torch.Tensor, - # [batch_size, max_num_blocks_per_req] - block_table: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - device = cu_num_tokens.device - cu_num_tokens = cu_num_tokens.cpu() - block_table = block_table.cpu() - num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 - target_positions = target_positions.cpu() - if self.method == "eagle3": - assert isinstance(self.model, Eagle3LlamaForCausalLM) - target_hidden_states = self.model.combine_hidden_states( - target_hidden_states) - assert target_hidden_states.shape[-1] == self.hidden_size - - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. 
- # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids[0] - - query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] - max_query_len = query_lens.max().item() - - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.runner.query_start_loc[:batch_size + 1], - query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size + - 1], - seq_lens_cpu=self.runner.seq_lens_cpu, - max_query_len=max_query_len, - num_reqs=batch_size, - num_actual_tokens=num_tokens, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0]. - get_device_tensor(), - slot_mapping_cpu=target_slot_mapping, - positions=target_positions, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - attn_state=self.runner.attn_state, - decode_token_per_req=self.runner.decode_token_per_req, - ) - # FIXME(woosuk): The below two ops cause synchronization. Optimize. - attn_metadata = self.runner.attn_metadata_builder.build( - common_attn_metadata, self.runner.model) - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - else: - num_input_tokens = num_tokens - # copy inputs to buffer for cudagraph - self.positions[:num_tokens] = target_positions.to(device) - self.hidden_states[:num_tokens] = target_hidden_states - attn_metadata.block_tables = block_table.to(device) - with set_ascend_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): - last_hidden_states, hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - hidden_states=self.hidden_states[:num_input_tokens], - ) - sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - draft_token_ids = logits.argmax(dim=-1) - - # Early exit if 
there is only one draft token to be generated. - if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - # Generate the remaining draft tokens. - draft_token_ids_tensor = torch.zeros( - (self.num_speculative_tokens, *draft_token_ids.shape), - dtype=draft_token_ids.dtype) - draft_token_ids_tensor[0] = draft_token_ids - - positions_cpu = target_positions[last_token_indices].cpu().to( - torch.int64) - hidden_states = hidden_states[last_token_indices] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: - input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) - else: - input_batch_size = batch_size - attn_metadata.num_actual_tokens = batch_size - attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size + 1] - - if self.num_speculative_tokens > 2: - raise ValueError("Speculative tokens > 2 are not supported yet.") - - attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill - for now_speculative in range(self.num_speculative_tokens - 1): - # Update the inputs. - # cast to int32 is crucial when eagle model is compiled. - # tensor.argmax() returns int64 by default. - input_ids = draft_token_ids_tensor[now_speculative].to(device) - positions_cpu += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions_cpu >= self.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. 
- clamped_positions_cpu = torch.where(exceeds_max_model_len, 0, - positions_cpu) - clamped_positions = clamped_positions_cpu.to(device) - - # TODO: Increment the sequence lengths. - - attn_metadata.seq_lens += 1 - # TODO: Consider max model length. - # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - # self.max_model_len) - # For the requests that exceed the max model length, we set the - # TODO: sequence length to 1 to minimize their overheads in attention. - - # Compute the slot mapping. - block_numbers = (clamped_positions_cpu // self.block_size) - block_ids = block_table.gather(dim=1, - index=block_numbers.view(-1, 1)) - block_ids = block_ids.view(-1) - slot_mapping_cpu = (block_ids * self.block_size + - clamped_positions_cpu % self.block_size) - - # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadvertently updated with the - # padding tokens. - slot_mapping_cpu.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) - # NOTE: ASCEND slot_mapping must on cpu - attn_metadata.slot_mapping = slot_mapping_cpu.to( - torch.int32).to(device) - # copy inputs to buffer for cudagraph - self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions - self.hidden_states[:batch_size] = hidden_states - positions = positions_cpu.to(device) - attn_mask = self._make_attention_mask( - seq_lens=attn_metadata.seq_lens, - position=positions, - ) - attn_metadata.attn_mask = attn_mask - attn_metadata.block_tables = block_table.to(device) - # Run the model. 
- with set_ascend_forward_context(attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): - - last_hidden_states, hidden_states = self.model( - input_ids=self.input_ids[:input_batch_size], - positions=self.positions[:input_batch_size], - hidden_states=self.hidden_states[:input_batch_size], - ) - hidden_states = hidden_states[:batch_size] - logits = self.model.compute_logits(last_hidden_states[:batch_size], - None) - - # TODO(wenlong): get more than one token for tree attention - draft_token_ids = logits.argmax(dim=-1) - draft_token_ids_tensor[now_speculative + 1] = draft_token_ids.cpu() - - # [batch_size, num_speculative_tokens] - draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) - return draft_token_ids - - @staticmethod - def prepare_inputs( - # [batch_size + 1] - cu_target_query_lens: torch.Tensor, - # [batch_size] - num_rejected_tokens: torch.Tensor, - num_tokens: int, - ) -> tuple[torch.Tensor, torch.Tensor]: - # cu_target_query_lens: [0, a, a + b, a + b + c] - # num_rejected_tokens: [n1, n2, n3] - # num_tokens_per_req: [a - n1, b - n2, c - n3] - # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - # token_indices: [0, 1, ..., a - n1 - 1, - # a, a + 1, ..., a + b - n2 - 1, - # a + b, a + b + 1, ..., a + b + c - n3 - 1] - - # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) - # [a, b, c] -> [a - n1, b - n2, c - n3] - num_tokens_per_req = query_len_per_req - num_rejected_tokens - - # [a - n1, b - n2, c - n3] -> - # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - cu_num_tokens = torch.zeros_like(cu_target_query_lens) - torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - token_indices = torch.empty( - num_tokens, - dtype=torch.int32, - device=cu_target_query_lens.device, - ) - BLOCK_SIZE = 1024 - prepare_eagle_input_sequential( - token_indices, - cu_target_query_lens, - cu_num_tokens, - block_size=BLOCK_SIZE, - ) - return cu_num_tokens, 
token_indices - - def load_model(self, target_model: nn.Module) -> None: - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config - target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) - - self.model = get_model(vllm_config=self.vllm_config, - model_config=draft_model_config) - - draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - target_attn_layer_names) - - self.attn_layer_names = list(draft_attn_layer_names) - self.attn_layer_name = next(iter(draft_attn_layer_names)) - # share embed_tokens with the target model if needed - if get_pp_group().world_size == 1: - logger.info( - "The EAGLE head shares the same vocab embedding" \ - " with the target model." - ) - self.model.model.embed_tokens = target_model.model.embed_tokens - else: - logger.info( - "Since PP > 1, the EAGLE head loaded its own vocab embedding" \ - " weights instead of sharing them with the target model." - ) - - # share lm_head with the target model if needed - # some model definition do not define lm_head explicitly - # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM - if self.vllm_config.speculative_config.method != "eagle3" and \ - hasattr(target_model, "lm_head"): - logger.info("Loading EAGLE LM head weights from the target model.") - if supports_multimodal(target_model): - self.model.lm_head = target_model.get_language_model().lm_head - else: - self.model.lm_head = target_model.lm_head - - @torch.inference_mode() - def dummy_run( - self, - num_tokens: int, - ) -> None: - with set_ascend_forward_context(None, - self.vllm_config, - num_tokens=num_tokens): - self.model( - input_ids=self.input_ids[:num_tokens], - positions=self.positions[:num_tokens], - hidden_states=self.hidden_states[:num_tokens], - ) - - -def prepare_eagle_input_sequential(out_tensor: torch.Tensor, - cu_query_lens: torch.Tensor, - cu_num_tokens: torch.Tensor, - block_size: int): - num_programs = 
len(cu_num_tokens) - 1 - for pid in range(num_programs): - start_pos = cu_num_tokens[pid].item() - end_pos = cu_num_tokens[pid + 1].item() - num_tokens = end_pos - start_pos - index_start = cu_query_lens[pid].item() - num_blocks = int( - torch.ceil(torch.tensor(num_tokens / block_size)).item()) - - for i in range(num_blocks): - offset_tensor = torch.arange(0, - block_size, - dtype=torch.int32, - device=out_tensor.device) - global_start_offset = i * block_size - target_indices = torch.tensor( - start_pos + global_start_offset, - dtype=torch.int32, - device=out_tensor.device) + offset_tensor - values_to_store = torch.tensor( - index_start, dtype=torch.int32, - device=out_tensor.device) + offset_tensor - mask = (target_indices >= start_pos) & \ - (target_indices < end_pos) & \ - (offset_tensor < num_tokens) - out_tensor[target_indices[mask]] = values_to_store[mask] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7a9fe1b..9281dd7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -19,11 +19,16 @@ import copy import gc -import math +import itertools import time +from collections import defaultdict +from collections.abc import Iterator from contextlib import contextmanager, nullcontext +from copy import deepcopy from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast +from multiprocessing import Manager +from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, + Union, cast) import numpy as np import numpy.typing as npt @@ -33,10 +38,13 @@ import torch.distributed as dist import torch.nn as nn from tqdm import tqdm # type: ignore from vllm.attention import AttentionType, get_attn_backend +from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import 
set_cudagraph_capturing_enabled -from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig +from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config) +from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 @@ -45,7 +53,8 @@ from vllm.distributed.parallel_state import (get_dp_group, get_pp_group, is_global_first_rank) from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.model_executor.models.interfaces import supports_transcription @@ -55,52 +64,66 @@ from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LazyLoader, cdiv, is_pin_memory_available) + LazyLoader, cdiv, get_dtype_size, + is_pin_memory_available) +from vllm.utils.jsontree import json_map_leaves +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher -from vllm.v1.kv_cache_interface import 
(FullAttentionSpec, KVCacheConfig, - KVCacheSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, + KVCacheConfig, KVCacheGroupSpec, + KVCacheSpec, MambaSpec) +# yapf: enable +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders, +from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache, + gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.ascend_forward_context import (MoECommType, + set_ascend_forward_context) from vllm_ascend.attention.attention_mask import AttentionMaskBuilder -from vllm_ascend.attention.attention_v1 import (AscendAttentionState, - AscendMetadata) -from vllm_ascend.attention.mla_v1 import AscendMLAMetadata +from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.compilation.acl_graph import ACLGraphWrapper +from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, + set_graph_params, + update_attn_params) +from vllm_ascend.eplb.adaptor.vllm_adaptor import 
VllmEplbAdaptor +from vllm_ascend.eplb.core.eplb_device_transfer_loader import \ + D2DExpertWeightLoader +from vllm_ascend.eplb.core.eplb_worker import EplbProcess +from vllm_ascend.eplb.eplb_updator import EplbUpdator +from vllm_ascend.eplb.utils import model_register +from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform +from vllm_ascend.sample.logits_processor import build_logitsprocs from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler -from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata -from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata +from vllm_ascend.spec_decode import get_spec_decode_method +from vllm_ascend.spec_decode.eagle_proposer import EagleProposer +from vllm_ascend.spec_decode.interface import SpecDcodeType +from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, ProfileExecuteDuration, get_ascend_soc_version, is_310p, lmhead_tp_enable, vllm_version_is) -from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer -from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch -if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")): - from vllm.v1.outputs import DraftTokenIds -else: - DraftTokenIds = None - if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import SchedulerOutput @@ -120,6 +143,13 @@ if is_310p(): else: ACL_FORMAT = ACL_FORMAT_FRACTAL_ND +if not vllm_version_is("0.10.2"): + from vllm.v1.kv_cache_interface import UniformTypeKVCacheSpecs + from vllm.v1.outputs import PoolerOutput +else: + from vllm.sequence import PoolerOutput + UniformTypeKVCacheSpecs = None + @dataclass class GraphCaptureContext: @@ 
-158,6 +188,53 @@ def graph_capture(device: torch.device): yield graph_capture_context +# Wrapper for ModelRunnerOutput to support overlapped execution. +class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput): + + def __init__( + self, + model_runner_output: ModelRunnerOutput, + sampled_token_ids: torch.Tensor, + invalid_req_indices: list[int], + async_output_copy_stream: torch.npu.Stream, + ): + self._model_runner_output = model_runner_output + self._invalid_req_indices = invalid_req_indices + + # Event on the copy stream so we can synchronize the non-blocking copy. + self._async_copy_ready_event = torch.npu.Event() + + # Keep a reference to the device tensor to avoid it being + # deallocated until we finish copying it to the host. + self._sampled_token_ids = sampled_token_ids + + # Initiate the copy on a separate stream, but do not synchronize it. + default_stream = torch.npu.current_stream() + with torch.npu.stream(async_output_copy_stream): + async_output_copy_stream.wait_stream(default_stream) + self._sampled_token_ids_cpu = self._sampled_token_ids.to( + 'cpu', non_blocking=True) + self._async_copy_ready_event.record() + + def get_output(self) -> ModelRunnerOutput: + """Copy the device tensors to the host and return a ModelRunnerOutput. + + This function blocks until the copy is finished. 
+ """ + self._async_copy_ready_event.synchronize() + + # Release the device tensor once the copy has completed + del self._sampled_token_ids + + valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() + for i in self._invalid_req_indices: + valid_sampled_token_ids[i].clear() + + output = self._model_runner_output + output.sampled_token_ids = valid_sampled_token_ids + return output + + class NPUModelRunner(LoRAModelRunnerMixin): def __init__(self, vllm_config: VllmConfig, device: torch.device): @@ -175,10 +252,17 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, self.block_size) self.max_num_tokens = self.scheduler_config.max_num_batched_tokens - self.max_num_reqs = self.scheduler_config.max_num_seqs + decode_max_num_seqs = getattr(self.scheduler_config, + 'decode_max_num_seqs', 0) + self.max_num_reqs = max(self.scheduler_config.max_num_seqs, + decode_max_num_seqs) self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.device = device + if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP: + self.prefetch_stream = torch.npu.Stream(device=device) + else: + self.prefetch_stream = None self.dtype = self.model_config.dtype if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION: # TODO: drop the env config to use ascend sampler by default @@ -189,19 +273,20 @@ class NPUModelRunner(LoRAModelRunnerMixin): from vllm.v1.sample.sampler import Sampler self.sampler = Sampler() + self.reorder_batch_threshold: Optional[int] = None # Lazy initialization, these will be set after __init__ self.kv_caches: List[torch.Tensor] = [] - # TODO: remove Dict[str, Dict[int, torch.Tensor]] type after 0.10.1.1 - self.encoder_cache: Union[Dict[str, Dict[int, torch.Tensor]], - Dict[str, torch.Tensor]] = {} + self.attn_groups: list[list[AttentionGroup]] = [] + self.encoder_cache: Dict[str, torch.Tensor] = {} self.attn_mask = None self.attn_state = None 
self.requests: Dict[str, CachedRequestState] = {} self.intermediate_tensors: Optional[IntermediateTensors] = None + self.runner_only_attn_layers: set[str] = set() - ascend_config = get_ascend_config() - if ascend_config.ascend_scheduler_config.enabled: + self.ascend_config = get_ascend_config() + if self.ascend_config.ascend_scheduler_config.enabled: self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled else: self.chunked_prefill_enabled = True @@ -211,6 +296,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): else: self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.cache_config.cache_dtype] + # use_hybrid_blocks: if hybrid blocks is used. + self.use_hybrid_blocks: bool = False + self.need_accepted_tokens: bool = False self.is_multimodal_model = self.model_config.is_multimodal_model self.is_pooling_model = self.model_config.pooler_config is not None @@ -219,58 +307,50 @@ class NPUModelRunner(LoRAModelRunnerMixin): (self.max_num_tokens, self.model_config.get_hidden_size()), dtype=self.dtype, device=self.device) - # Set up Attention - self.attn_backend = get_attn_backend( - 0, - self.dtype, - None, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - vllm_config, device) - self.attn_mask_builder = AttentionMaskBuilder( - self.model_config.max_model_len, self.dtype) + if vllm_version_is("0.10.2"): + self.attn_backend = get_attn_backend( + 0, + self.dtype, + None, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + use_sfa=self.ascend_config.use_sfa) + else: + self.attn_backend = get_attn_backend( + 0, + self.dtype, + None, + self.block_size, + use_mla=self.model_config.use_mla, + use_sfa=self.ascend_config.use_sfa) + if torch.version.cann.startswith("8.3"): + self.attn_mask_builder = AttentionMaskBuilder( + self.scheduler_config.max_num_batched_tokens, self.dtype, + self.device) + else: + 
self.attn_mask_builder = AttentionMaskBuilder( + self.model_config.max_model_len, self.dtype) # Set up speculative decoding. - self.use_aux_hidden_state_outputs = False - self.use_spec_decode = False self.spec_attn_mask = None - self.use_eagle = False self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer]] = None - self.actual_seq_lengths_q = [] + self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1 if self.speculative_config: - self.use_spec_decode = True spec_token_num = self.speculative_config.num_speculative_tokens assert spec_token_num > 0 self.decode_token_per_req = 1 + spec_token_num - self.actual_seq_lengths_q = [ - len for len in - range(self.decode_token_per_req, self.max_num_tokens + - 1, self.decode_token_per_req) - ] self.spec_attn_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool), diagonal=1).to(self.device) if get_pp_group().is_last_rank: - if self.speculative_config.method == "ngram": - self.drafter = NgramProposer(self.vllm_config) - elif self.speculative_config.method in ["eagle", "eagle3"]: - self.use_eagle = True - self.drafter = EagleProposer(self.vllm_config, self.device, - self) # type: ignore - if self.speculative_config.method == "eagle3": - self.use_aux_hidden_state_outputs = True - elif self.speculative_config.method == 'deepseek_mtp': - self.drafter = MtpProposer(self.vllm_config, self) - else: - raise ValueError("Unknown speculative decoding method: " - f"{self.speculative_config.method}") + self.drafter = get_spec_decode_method( + self.speculative_config.method, self.vllm_config, + self.device, self) self.rejection_sampler = AscendRejectionSampler() # Persistent batch. 
@@ -286,6 +366,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.seq_lens = torch.zeros(self.max_num_reqs, dtype=torch.int32, device=self.device) + self.slot_mapping = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) self.uses_mrope = self.model_config.uses_mrope # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -366,12 +449,103 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer - self.mc2_tokens_capacity = 512 * self.parallel_config.tensor_parallel_size + # NOTE: Technically, MC2 can have 512 tokens each rank, but this will consume too much memory. The formula is: + # ((maxBs * tokenNeedSizeDispatch * ep_worldsize * localMoeExpertNum) + (maxBs * tokenNeedSizeCombine * (k + sharedExpertNum))) * 2 + # so we have to limit the MC2 tokens to save memory, should fix this in the future. + self.mc2_tokens_capacity = 512 self.reserved_mc2_mask = torch.zeros( self.mc2_tokens_capacity, dtype=torch.bool, device=self.device, ) + self.dynamic_eplb = self.ascend_config.dynamic_eplb + if self.dynamic_eplb: + self.is_eplb_warmuped = False + self.eplb_loader = D2DExpertWeightLoader() + self.manager = Manager() + self.shared_dict = self.manager.dict({ + "expert_map": None, + "moe_load": None, + "expert_maps": None + }) + self.eplb_process = EplbProcess(shared_dict=self.shared_dict, + policy_type=1, + enable_d2d=True) + self.process = self.eplb_process._launch_process() + ascend_config = get_ascend_config() + self.eplb_updator = EplbUpdator(ascend_config, self.eplb_loader, + self.eplb_process, self.process) + + self.use_async_scheduling = self.scheduler_config.async_scheduling + self.async_output_copy_stream = torch.npu.Stream() if \ + self.use_async_scheduling else None + # Input Batch + # NOTE(Chen): Ideally, we should initialize the input batch inside + # `initialize_kv_cache` based on the kv cache config. 
However, as in + # https://github.com/vllm-project/vllm/pull/18298, due to some unknown + # reasons, we have to initialize the input batch before `load_model`, + # quantization + weight offloading will fail otherwise. As a temporary + # solution, we initialize the input batch here, and re-initialize it + # in `initialize_kv_cache` if the block_sizes here is different from + # the block_sizes in the kv cache config. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.model_config.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=[self.block_size], + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=build_logitsprocs( + self.vllm_config, self.device, self.pin_memory, + self.is_pooling_model, + self.vllm_config.model_config.logits_processors), + is_pooling_model=self.is_pooling_model, + kernel_block_sizes=None, + ) + self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int64) + self.num_draft_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) + + def _make_buffer(self, + *size: Union[int, torch.SymInt], + dtype: torch.dtype, + numpy: bool = True) -> CpuGpuBuffer: + # Bfloat16 torch tensors cannot be directly cast to a numpy array, so + # if a bfloat16 buffer is needed without a corresponding numpy array, + # don't bother instantiating the numpy array. + return CpuGpuBuffer(*size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy) + + def _update_states_after_model_execute( + self, output_token_ids: torch.Tensor) -> None: + """Update the cached states after model execution. + + This is used for MTP/EAGLE for hybrid models, as in linear attention, + only the last token's state is kept. 
In MTP/EAGLE, for draft tokens + the state are kept util we decide how many tokens are accepted for + each sequence, and a shifting is done during the next iteration + based on the number of accepted tokens. + """ + if not self.model_config.is_hybrid or not self.speculative_config: + return + + # Find the number of accepted tokens for each sequence. + num_accepted_tokens = (torch.cat( + [ + output_token_ids, + torch.full((output_token_ids.size(0), 1), + -1, + device=output_token_ids.device), + ], + dim=1) == -1).int().argmax(-1).cpu().numpy() + for i, num_tokens in enumerate(num_accepted_tokens): + self.input_batch.num_accepted_tokens_cpu[i] = num_tokens def _use_aclgraph(self) -> bool: return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager @@ -380,8 +554,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -390,17 +563,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): # and handling the second as a new request. for req_id in scheduler_output.finished_req_ids: self.input_batch.remove_request(req_id) - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - # Free the cached encoder outputs. 
- for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) - else: - for mm_hash in scheduler_output.free_encoder_mm_hashes: - self.encoder_cache.pop(mm_hash, None) + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests # or running requests that are not scheduled in this step. We remove @@ -437,11 +601,17 @@ class NPUModelRunner(LoRAModelRunnerMixin): to_update = model.pooler.get_pooling_updates(task) to_update.apply(pooling_params) + backward_kwargs = {} + if vllm_version_is("0.10.2"): + backward_kwargs["mm_kwargs"] = new_req_data.mm_kwargs + backward_kwargs["mm_hashes"] = new_req_data.mm_hashes + backward_kwargs["mm_positions"] = new_req_data.mm_positions + else: + backward_kwargs["mm_features"] = new_req_data.mm_features + self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - mm_kwargs=new_req_data.mm_kwargs, - mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, pooling_params=pooling_params, generator=generator, @@ -449,51 +619,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], lora_request=new_req_data.lora_request, - **({ - "mm_hashes": new_req_data.mm_hashes - } if not (vllm_version_is("0.10.1.1") - or vllm_version_is("0.10.1")) else { - "mm_hashes": None - }), + **backward_kwargs, ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for mm_item in self.requests[req_id].mm_kwargs: - mm_input = 
mm_item.get_data() - if mm_input.get("image_grid_thw") is not None: - image_grid_thw.append( - mm_input["image_grid_thw"].tolist()) - if mm_input.get("video_grid_thw") is not None: - video_grid_thw.append( - mm_input["video_grid_thw"].tolist()) - if mm_input.get("second_per_grid_ts") is not None: - second_per_grid_ts.append( - mm_input["second_per_grid_ts"]) - if mm_input.get("audio_feature_lengths") is not None: - audio_feature_lengths.append( - mm_input["audio_feature_lengths"]) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True - - hf_config = self.model_config.hf_config - - self.requests[req_id].mrope_positions, \ - self.requests[req_id].mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - self.requests[req_id].prompt_token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + if vllm_version_is("0.10.2"): + self._init_mrope_positions_0102(self.requests[req_id]) + else: + self._init_mrope_positions(self.requests[req_id]) req_ids_to_add.append(req_id) @@ -586,14 +720,88 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Condense the batched states if there are gaps left by removed requests self.input_batch.condense() - + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) # Refresh batch metadata with any pending updates. 
self.input_batch.refresh_metadata() + def _init_mrope_positions(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + assert req_state.mm_features is not None + for mm_feature in req_state.mm_features: + mm_item = mm_feature.data + if mm_item is None: + continue + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + req_state.mrope_positions, req_state.mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + def _init_mrope_positions_0102(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + assert req_state.mm_kwargs is not None + for mm_item in req_state.mm_kwargs: + mm_input = mm_item.get_data() + if mm_input.get("image_grid_thw") is not None: + image_grid_thw.append(mm_input["image_grid_thw"].tolist()) + if mm_input.get("video_grid_thw") is not None: + video_grid_thw.append(mm_input["video_grid_thw"].tolist()) + if mm_input.get("second_per_grid_ts") is not None: + second_per_grid_ts.append(mm_input["second_per_grid_ts"]) + if mm_input.get("audio_feature_lengths") is not None: + 
audio_feature_lengths.append(mm_input["audio_feature_lengths"]) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + hf_config = self.model_config.hf_config + + req_state.mrope_positions, req_state.mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + req_state.prompt_token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + def _sync_metadata_across_dp( self, num_tokens: int, with_prefill: bool, enable_dbo: bool ) -> tuple[int, Optional[torch.Tensor], bool, bool]: - if self.dp_size == 1 or self.vllm_config.model_config.enforce_eager: + # TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in + # our case, we still need to sync the other two flags as well. So we need to + # include them in the all_reduce operation, and more over, we CANNOT skip it + # even if we are running in eager mode, which harms performance. + # FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here + # immediately once the other two flags are no longer needed. + if self.dp_size == 1: return num_tokens, None, with_prefill, enable_dbo # Sync num_tokens, with_prefill, enable_dbo across dp ranks @@ -649,152 +857,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): return False return True - def get_eagle_atten_dict( - self, - scheduler_output: "SchedulerOutput", - ) -> dict[str, Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata, AscendMLATorchairMetadata]]: - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - num_reqs = self.input_batch.num_reqs - assert num_reqs > 0 - - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. 
- self.input_batch.block_table.commit_block_table(num_reqs) - - # Get the number of scheduled tokens for each request. - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = max(tokens) - self.query_lens = torch.from_numpy(num_scheduled_tokens) - # Get request indices. - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) - - # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) - - # Get positions. - positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) - - # Calculate M-RoPE positions. - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.uses_mrope: - self._calc_mrope_positions(scheduler_output) - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # where M is the max_model_len. - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) - - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) - - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - # NOTE(Chen): there is exactly one KV cache group that contains all - # attetnion layers in the model for now, so the current logic for - # getting attn_metadata is not related to kv_cache_group information. 
- # Will extend this part to support multiple KV cache groups later. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - block_size = kv_cache_group_spec.kv_cache_spec.block_size - block_table = self.input_batch.block_table[kv_cache_group_id] - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` - # here because M (max_model_len) is not necessarily divisible by - # block_size. - block_table_indices = ( - req_indices * block_table.max_num_blocks_per_req + - positions_np // block_size) - block_table_cpu = block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten( - )[block_table_indices].numpy() - block_offsets = positions_np % block_size - np.add( - block_numbers * block_size, - block_offsets, - out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) - - # Prepare the attention metadata. - self.query_start_loc_np[0] = 0 - self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens - - self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) - - # Copy the tensors to the NPU. 
- self.input_ids[:total_num_scheduled_tokens].copy_( - self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - if self.uses_mrope: - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.mrope_positions[:, :total_num_scheduled_tokens].copy_( - self.mrope_positions_cpu[:, :total_num_scheduled_tokens], - non_blocking=True) - else: - # Common case (1D positions) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], - non_blocking=True) - - self.query_start_loc[:num_reqs + 1].copy_( - self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) - self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], - non_blocking=True) - - # Fill unused with -1. Needed for reshape_and_cache - self.seq_lens[num_reqs:].fill_(0) - self.query_start_loc[num_reqs + 1:].fill_(-1) - - attn_metadata: dict[str, Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata, - AscendMLATorchairMetadata]] = {} - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens_cpu=self.seq_lens_cpu, - num_reqs=num_reqs, - max_query_len=max_num_scheduled_tokens, - num_actual_tokens=total_num_scheduled_tokens, - actual_seq_lengths_q=self.actual_seq_lengths_q, - block_table_tensor=self.input_batch.block_table[0]. 
- get_device_tensor(), - slot_mapping_cpu=self.slot_mapping_cpu, - positions=self.positions, - attn_mask=self.attn_mask, - spec_attn_mask=self.spec_attn_mask, - attn_state=self.attn_state, - decode_token_per_req=self.decode_token_per_req, - ) - attn_metadata_i = self.attn_metadata_builder.build( - common_attn_metadata, self.get_model()) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i - - return attn_metadata - def get_model(self) -> nn.Module: # get raw model out of the aclgraph wrapper. if isinstance(self.model, ACLGraphWrapper): @@ -829,12 +891,16 @@ class NPUModelRunner(LoRAModelRunnerMixin): def _make_attention_mask(self, seq_lens, position, attn_state) -> torch.Tensor: # Chunk Prefill situation. - if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla: - return self.attn_mask_builder.get_splitfuse_attn_mask( - seq_lens, position, self.dtype, self.device) + if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa: + if torch.version.cann.startswith("8.3"): + return self.attn_mask_builder.get_splitfuse_attn_mask() + else: + return self.attn_mask_builder.get_splitfuse_attn_mask( + seq_lens, position, self.dtype, self.device) + # Prefill without cache situation. elif attn_state == AscendAttentionState.PrefillNoCache: - max_seq_len = max(seq_lens, default=0) + max_seq_len = max(seq_lens.max().item(), 0) return self.attn_mask_builder.get_attn_mask( max_seq_len, self.dtype, self.device) # Prefill with cache hit. @@ -900,34 +966,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): return # Batch the multi-modal inputs. 
- mm_kwargs = list[MultiModalKwargsItem]() - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + if vllm_version_is("0.10.2"): + mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler_0102( + scheduler_output) else: - mm_hashes_pos = list[tuple[str, PlaceholderRange]]() - for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): - req_state = self.requests[req_id] - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - for mm_input_id in encoder_input_ids: - mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - req_ids_pos.append((req_id, mm_input_id, - req_state.mm_positions[mm_input_id])) - else: - for mm_input_id in encoder_input_ids: - # TODO remove this assert after 0.10.1.1 - assert req_state.mm_hashes is not None - mm_hash = req_state.mm_hashes[mm_input_id] - mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - mm_hashes_pos.append( - (mm_hash, req_state.mm_positions[mm_input_id])) - # Batch mm inputs as much as we can: if a request in the batch has - # multiple modalities or a different modality than the previous one, - # we process it separately to preserve item order. - # FIXME(ywang96): This is a hacky way to deal with multiple modalities - # in the same batch while still being able to benefit from batching - # multimodal inputs. The proper solution should be reordering the - # encoder outputs. + mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( + scheduler_output) encoder_outputs = [] + for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( mm_kwargs, device=self.device, @@ -950,79 +996,121 @@ class NPUModelRunner(LoRAModelRunnerMixin): for output in curr_group_outputs: encoder_outputs.append(output) - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - # Cache the encoder outputs. 
- for (req_id, input_id, pos_info), output in zip( - req_ids_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} - self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( - output, - is_embed=pos_info.is_embed, - ) - else: - for (mm_hash, pos_info), output in zip(mm_hashes_pos, - encoder_outputs): - self.encoder_cache[mm_hash] = scatter_mm_placeholders( - output, - is_embed=pos_info.is_embed, - ) + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + self.encoder_cache[mm_hash] = scatter_mm_placeholders( + output, + is_embed=pos_info.is_embed, + ) + + # TODO: remove this once we drop support for vLLM 0.10.2 + def _batch_mm_kwargs_from_scheduler_0102( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return [], [] + # Batch the multi-modal inputs. + mm_kwargs = list[MultiModalKwargsItem]() + # list of tuple (mm_hash, position_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + assert req_state.mm_hashes is not None + assert req_state.mm_kwargs is not None + assert req_state.mm_positions is not None + for mm_input_id in encoder_input_ids: + mm_hash = req_state.mm_hashes[mm_input_id] + mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) + mm_hashes_pos.append( + (mm_hash, req_state.mm_positions[mm_input_id])) + + return mm_kwargs, mm_hashes_pos + + def _batch_mm_kwargs_from_scheduler( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: + """Batch multimodal kwargs from scheduled encoder inputs. + + Args: + scheduler_output: The scheduler output containing scheduled encoder + inputs. 
+ + Returns: + A tuple of (mm_kwargs, req_ids_pos) where: + - mm_kwargs: List of multimodal kwargs items to be batched + - mm_hashes_pos: List of (mm_hash, position_info) tuples + """ + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return [], [] + # Batch the multi-modal inputs. + mm_kwargs = list[MultiModalKwargsItem]() + # list of tuple (mm_hash, position_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + assert req_state.mm_features is not None + for mm_input_id in encoder_input_ids: + mm_feature = req_state.mm_features[mm_input_id] + mm_hash = mm_feature.identifier + mm_kwargs.append(mm_feature.data) + mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) + + return mm_kwargs, mm_hashes_pos def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", ) -> list[torch.Tensor]: + + def _iter_mm_features(req_state: CachedRequestState): + if vllm_version_is("0.10.2"): + # legacy path (to be removed later) + assert req_state.mm_hashes is not None + assert req_state.mm_positions is not None + for mm_hash, pos_info in zip(req_state.mm_hashes, + req_state.mm_positions): + yield mm_hash, pos_info, getattr(pos_info, "is_embed", + None) + else: + assert req_state.mm_features is not None + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position + yield mm_feature.identifier, pos_info, getattr( + pos_info, "is_embed", None) + mm_embeds: list[torch.Tensor] = [] + for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens - mm_positions = req_state.mm_positions - if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")): - mm_hashes = req_state.mm_hashes - for i, pos_info in enumerate(mm_positions): + + for mm_hash, pos_info, 
is_embed in _iter_mm_features(req_state): start_pos = pos_info.offset num_encoder_tokens = pos_info.length - # The encoder output is needed if the two ranges overlap: - # [num_computed_tokens, - # num_computed_tokens + num_scheduled_tokens) and - # [start_pos, start_pos + num_encoder_tokens) if start_pos >= num_computed_tokens + num_scheduled_tokens: - # The encoder output is not needed in this step. break if start_pos + num_encoder_tokens <= num_computed_tokens: - # The encoder output is already processed and stored - # in the decoder's KV cache. continue start_idx = max(num_computed_tokens - start_pos, 0) - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) - assert start_idx < end_idx - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] - encoder_output = self.encoder_cache[req_id][i] - else: - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens, - ) - assert start_idx < end_idx - # TODO remove this assert after 0.10.1.1 - assert mm_hashes is not None - mm_hash = mm_hashes[i] - encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None,\ - f"Encoder cache miss for {mm_hash}." + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens, + ) + assert start_idx < end_idx - if (is_embed := pos_info.is_embed) is not None: + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None, \ + f"Encoder cache miss for {mm_hash}." 
+ + if is_embed is not None: is_embed = is_embed[start_idx:end_idx] mm_embeds_item = gather_mm_placeholders( @@ -1032,38 +1120,147 @@ class NPUModelRunner(LoRAModelRunnerMixin): mm_embeds.append(mm_embeds_item) return mm_embeds + def _get_cumsum_and_arange( + self, + num_tokens: np.ndarray, + cumsum_dtype: Optional[np.dtype] = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + + def _prepare_input_ids(self, total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray) -> None: + """Prepare the input IDs for the current batch. + + Carefully handles the `prev_sampled_token_ids` which can be cached + from the previous engine iteration, in which case those tokens on the + NPU need to be copied into the corresponding slots into input_ids.""" + + if self.input_batch.prev_sampled_token_ids is None: + # Normal scheduling case + self.input_ids[:total_num_scheduled_tokens].copy_( + self.input_ids_cpu[:total_num_scheduled_tokens], + non_blocking=True) + return + + # Async scheduling case, where some decode requests from the previous + # iteration won't have entries in input_ids_cpu and need to be copied + # on the NPU from prev_sampled_token_ids. 
+ prev_req_id_to_index = self.input_batch.prev_req_id_to_index + assert prev_req_id_to_index is not None + flattened_indices = [] + prev_common_req_indices = [] + indices_match = True + max_flattened_index = -1 + for req_id, cur_index in self.input_batch.req_id_to_index.items(): + if (prev_index := prev_req_id_to_index.get(req_id)) is not None: + prev_common_req_indices.append(prev_index) + # We need to compute the flattened input_ids index of the + # last token in each common request. + flattened_index = cu_num_tokens[cur_index].item() - 1 + flattened_indices.append(flattened_index) + indices_match &= (prev_index == flattened_index) + max_flattened_index = max(max_flattened_index, flattened_index) + num_commmon_tokens = len(flattened_indices) + if num_commmon_tokens < total_num_scheduled_tokens: + # If not all requests are decodes from the last iteration, + # We need to copy the input_ids_cpu to the NPU first. + self.input_ids[:total_num_scheduled_tokens].copy_( + self.input_ids_cpu[:total_num_scheduled_tokens], + non_blocking=True) + if num_commmon_tokens == 0: + # No requests in common with the previous iteration + # So input_ids_cpu will have all the input ids. + return + if indices_match and max_flattened_index == (num_commmon_tokens - 1): + # Common-case optimization: the batch is unchanged + # and no reordering happened. + # The indices are both the same permutation of 0..N-1 so + # we can copy directly using a single slice. + self.input_ids[:num_commmon_tokens].copy_( + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, + 0], + non_blocking=True) + return + # Upload the index tensors asynchronously + # so the scatter can be non-blocking. 
+ input_ids_index_tensor = torch.tensor(flattened_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to( + self.device, + non_blocking=True) + prev_common_req_indices_tensor = torch.tensor( + prev_common_req_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, non_blocking=True) + self.input_ids.scatter_(dim=0, + index=input_ids_index_tensor, + src=self.input_batch.prev_sampled_token_ids[ + prev_common_req_indices_tensor, 0]) + + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: + """ + Update the order of requests in the batch based on the attention + backend's needs. For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. + + Args: + scheduler_output: The scheduler output. + """ + # Attention free models have zero kv_cache_goups, however models + # like Mamba are also attention free but use the kv_cache for + # keeping its internal state. This is why we check the number + # of kv_cache groups instead of solely checking + # for self.model_config.is_attention_free. 
+ if len(self.kv_cache_config.kv_cache_groups) == 0: + return + + if self.reorder_batch_threshold is not None: + reorder_batch_to_split_decodes_and_prefills( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> tuple[Union[AscendMetadata, AscendMLAMetadata, AscendTorchairMetadata, - AscendMLATorchairMetadata], torch.Tensor, np.ndarray, int, - torch.Tensor, int, torch.Tensor, SpecDecodeMetadata, - Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor, + int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor], int]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - self.attn_metadata_builder.reorder_batch(self.input_batch, - scheduler_output) # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit_block_table(num_reqs) # Get the number of scheduled tokens for each request. - # TODO: The Python loop can be slow. Optimize. 
- num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) - num_valid_tokens = np.empty(num_reqs, dtype=np.int32) - max_num_scheduled_tokens = 0 - for i, req_id in enumerate(self.input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens[i] = num_tokens - num_valid_tokens[i] = num_tokens - \ - len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - max_num_scheduled_tokens = max(max_num_scheduled_tokens, - num_tokens) + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = num_scheduled_tokens.max() + num_valid_tokens = np.array([ + num_tokens - + len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) + for num_tokens, i in zip(tokens, req_ids) + ], + dtype=np.int32) if (self.use_aclgraph and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]): @@ -1104,13 +1301,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - # Prepare positions + # Get request indices. 
+ # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) - cu_num_tokens = np.cumsum(num_scheduled_tokens) - cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, - num_scheduled_tokens) - arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) positions_np = self.positions_np[:total_num_scheduled_tokens] np.add(self.input_batch.num_computed_tokens_cpu[req_indices], @@ -1127,74 +1326,67 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True) - self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_() - self.positions[:num_input_tokens].copy_( - self.positions_cpu[:num_input_tokens], non_blocking=True) - positions_cpu = self.positions_cpu[:num_input_tokens] - positions = self.positions[:num_input_tokens] - self.query_lens = torch.from_numpy(num_scheduled_tokens) + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) - self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) - seq_lens_cpu = self.seq_lens_cpu[:num_reqs] + # Prepare input_ids. + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. 
+ torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - - block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) - - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, - num_valid_tokens) - - self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, - position=positions_cpu, - attn_state=attn_state) - self.attn_state = attn_state # type: ignore + # Prepare some information for building Attention-Metadata + # Compute and commit slot mapping + self.input_batch.block_table.compute_slot_mapping( + req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping( + total_num_scheduled_tokens) self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens self.query_start_loc[:num_reqs + 1].copy_( self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) # Fill unused with -1. Needed for reshape_and_cache - self.seq_lens[num_reqs:].fill_(0) self.query_start_loc[num_reqs + 1:].fill_(-1) + self.seq_lens[num_reqs:].fill_(0) + + self.query_lens = torch.from_numpy(num_scheduled_tokens) + + # Copy the tensors to the NPU. 
+ self.input_ids[:total_num_scheduled_tokens].copy_( + self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + + self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_() + self.positions[:num_input_tokens].copy_( + self.positions_cpu[:num_input_tokens], non_blocking=True) + + # Make Attention metadata + positions_cpu = self.positions_cpu[:num_input_tokens] + positions = self.positions[:num_input_tokens] + seq_lens_cpu = self.seq_lens_cpu[:num_reqs] + self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, + position=positions_cpu, + attn_state=attn_state) + self.attn_state = attn_state # type: ignore self.with_prefill = with_prefill self.num_tokens_across_dp = num_tokens_across_dp self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens) - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens_cpu=self.seq_lens_cpu, - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - actual_seq_lengths_q=self.actual_seq_lengths_q, - block_table_tensor=self.input_batch.block_table[0]. 
- get_device_tensor(), - slot_mapping_cpu=self.slot_mapping_cpu, - positions=self.positions, - attn_mask=self.attn_mask, - spec_attn_mask=self.spec_attn_mask, - attn_state=self.attn_state, - enable_dbo_across_dp=enable_dbo, - is_only_prefill=bool(np.all(num_valid_tokens != 1)), - max_query_len=max_num_scheduled_tokens, - graph_pad_size=self.graph_pad_size, - decode_token_per_req=self.decode_token_per_req, - ) - attn_metadata = self.attn_metadata_builder.build( - common_attn_metadata, self.model) - if self.vllm_config.model_config.use_mla: - attn_metadata.num_input_tokens = num_input_tokens + attn_metadata: dict[str, Any] = {} # Prepare input_ids token_indices = (positions_np + @@ -1204,11 +1396,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) # Copy the tensors to the NPU. - self.input_ids[:total_num_scheduled_tokens].copy_( - self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) - # _prepare_inputs may reorder the batch, so we must gather multi - # modal outputs after that to ensure the correct order + # _prepare_inputs may reorder the batch, so we must gather + # multi-modal outputs after that to ensure the correct order if self.is_multimodal_model: # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) @@ -1277,6 +1468,92 @@ class NPUModelRunner(LoRAModelRunnerMixin): spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices + self.num_draft_tokens.np[:num_reqs] = num_draft_tokens + self.num_draft_tokens.np[num_reqs:].fill(0) + self.num_draft_tokens.copy_to_gpu() + + # Used in the below loop. 
+ # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + num_computed_tokens_cpu = ( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + spec_decode_common_attn_metadata = None + if use_spec_decode and self.need_accepted_tokens: + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor() + slot_mapping = blk_table.slot_mapping_cpu[: + total_num_scheduled_tokens] + self.slot_mapping[:total_num_scheduled_tokens].copy_( + slot_mapping[:total_num_scheduled_tokens], + non_blocking=True, + ) + + # Make AscendCommonAttentionMetadata + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens_cpu=self.seq_lens_cpu, + seq_lens=self.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + # TODO: change this to the right block table for linear attn + block_table_tensor=blk_table_tensor[:num_reqs], + slot_mapping=self.slot_mapping, + num_computed_tokens_cpu=num_computed_tokens_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + enable_dbo_across_dp=enable_dbo, + is_only_prefill=bool(np.all(num_valid_tokens != 1)), + max_query_len=max_num_scheduled_tokens, + graph_pad_size=self.graph_pad_size, + decode_token_per_req=self.decode_token_per_req, + ) + + if self.speculative_config and \ + 
spec_decode_common_attn_metadata is None: + spec_decode_common_attn_metadata = common_attn_metadata + + for attn_group in self.attn_groups[kv_cache_group_id]: + common_prefix_len = 0 + extra_attn_metadata_args = {} + if vllm_version_is("0.10.2"): + builder = attn_group.metadata_builder + else: + builder = attn_group.get_metadata_builder() + if isinstance(builder, GDNAttentionMetadataBuilder): + if use_spec_decode: + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens. + gpu[:num_reqs], + num_draft_tokens=self.num_draft_tokens. + gpu[:num_reqs], + ) + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args) + else: + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + model=self.model, + **extra_attn_metadata_args) + + if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa: + attn_metadata_i.num_input_tokens = num_input_tokens + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i if lmhead_tp_enable(): max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs @@ -1287,7 +1564,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): return (attn_metadata, positions, num_scheduled_tokens, num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, logits_indices, spec_decode_metadata, - input_ids, inputs_embeds, intermediate_tensors) + input_ids, inputs_embeds, intermediate_tensors, + max_num_scheduled_tokens) def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, maybe_padded_num_tokens, @@ -1301,6 +1579,17 @@ class NPUModelRunner(LoRAModelRunnerMixin): intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + + forward_context = get_forward_context() + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: + update_attn_params(self.update_stream, 
forward_context, + positions.shape[0]) + + if get_forward_context().sp_enabled: + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) + pad_size = get_forward_context().pad_size + if pad_size > 0: + hidden_states = hidden_states[:-pad_size, :] return hidden_states def _build_attn_state(self, num_reqs, num_scheduled_tokens, @@ -1317,7 +1606,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): attn_state = AscendAttentionState.SpecDecoding # Speculative decoding. elif np.all(num_valid_tokens == 1): - if self.use_eagle: + if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE + or self.drafter.name == SpecDcodeType.EAGLE3): attn_state = AscendAttentionState.ChunkedPrefill else: attn_state = AscendAttentionState.SpecDecoding @@ -1338,26 +1628,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): positions = self.mrope_positions[:, :num_input_tokens] return input_ids, positions - def _get_cumsum_and_arange( - self, - num_tokens: np.ndarray, - cumsum_dtype: Optional[np.dtype] = None, - ) -> tuple[np.ndarray, np.ndarray]: - """Get the cumulative sum and batched arange of the given array. - # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) - # Equivalent to but faster than: - # np.concatenate([np.arange(n) for n in num_tokens]) - """ - # Step 1. [2, 5, 3] -> [2, 7, 10] - cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) - total_num_tokens = cu_num_tokens[-1] - # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] - cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) - # Step 3. 
[0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - arange = self.arange_np[:total_num_tokens] - cumsums_offsets - - return cu_num_tokens, arange - def _calc_spec_decode_metadata( self, num_draft_tokens: np.ndarray, @@ -1506,77 +1776,19 @@ class NPUModelRunner(LoRAModelRunnerMixin): positions: torch.Tensor, num_scheduled_tokens: int, hidden_states: torch.Tensor, - attn_metadata: Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata, - AscendMLATorchairMetadata], + attn_metadata: dict[str, Any], aux_hidden_states: torch.Tensor = None, ) -> Optional[list[list[int]]]: - if not self.use_spec_decode: + if not self.drafter: # Speculative decoding is not enabled. draft_token_ids = None - elif self.speculative_config.method == "ngram": - draft_token_ids = self._generate_ngram_token_ids( - valid_sampled_token_ids) - elif self.speculative_config.method == "eagle": - raise NotImplementedError("Eagle Is Not Supported Yet.") - elif self.speculative_config.method == "eagle3": - draft_token_ids = self._generate_eagle3_token_ids( + else: + draft_token_ids = self.drafter.generate_token_ids( valid_sampled_token_ids, sampling_metadata, scheduler_output, spec_decode_metadata, positions, num_scheduled_tokens, - hidden_states, aux_hidden_states) - elif self.speculative_config.method == 'deepseek_mtp': - draft_token_ids = self._generate_mtp_token_ids( - valid_sampled_token_ids, sampling_metadata, scheduler_output, - spec_decode_metadata, positions, num_scheduled_tokens, - hidden_states, attn_metadata) + hidden_states, attn_metadata, aux_hidden_states) return draft_token_ids - def _pool_v010( - self, - hidden_states: torch.Tensor, - num_scheduled_tokens: int, - num_scheduled_tokens_np: np.ndarray, - finished_sending: Optional[set[str]] = None, - finished_recving: Optional[set[str]] = None, - kv_connector_output: Optional["KVConnectorOutput"] = None, - ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs ==\ - len(self.input_batch.pooling_params), \ - "Either all or none of the requests in" 
\ - " a batch must be pooling request" - - extracted_hidden_states = list( - torch.split(hidden_states[:num_scheduled_tokens], - num_scheduled_tokens_np.tolist())) - - pooling_metadata = self.input_batch.pooling_metadata - - raw_pooler_output = self.model.pooler( - hidden_states=extracted_hidden_states, - pooling_metadata=pooling_metadata) - - pooler_output: list[Optional[torch.Tensor]] = [] - seq_lens = self.seq_lens[:self.input_batch.num_reqs] - for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens, pooling_metadata.prompt_lens): - - if seq_len == prompt_len: - pooler_output.append(raw_output.data.cpu()) - else: - pooler_output.append(None) - extra_args = ({"kv_connector_output": kv_connector_output}) - modelrunner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=pooler_output, - **extra_args, - ) - return modelrunner_output - def _pool( self, hidden_states: torch.Tensor, @@ -1597,18 +1809,30 @@ class NPUModelRunner(LoRAModelRunnerMixin): device=hidden_states.device) seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs] - # Pooling models D2H & synchronize occurs in pooler.py:build_output - raw_pooler_output = self.model.pooler( - hidden_states=hidden_states, pooling_metadata=pooling_metadata) + if vllm_version_is("0.10.2"): + # Pooling models D2H & synchronize occurs in pooler.py:build_output + raw_pooler_output = self.model.pooler( + hidden_states=hidden_states, pooling_metadata=pooling_metadata) + else: + model = cast(VllmModelForPooling, self.model) + raw_pooler_output = model.pooler( + hidden_states=hidden_states, + pooling_metadata=pooling_metadata, + ) + raw_pooler_output = json_map_leaves( + lambda x: x.to("cpu", non_blocking=True), + raw_pooler_output, + ) + torch.npu.synchronize() pooler_output: list[Optional[torch.Tensor]] = [] for raw_output, seq_len, 
prompt_len in zip( raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - - if seq_len == prompt_len: - pooler_output.append(raw_output.data) + if vllm_version_is("0.10.2"): + output = raw_output.data if seq_len == prompt_len else None else: - pooler_output.append(None) + output = raw_output if seq_len == prompt_len else None + pooler_output.append(output) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, @@ -1620,30 +1844,73 @@ class NPUModelRunner(LoRAModelRunnerMixin): kv_connector_output=kv_connector_output, ) - def _select_moe_comm_method(self, num_tokens: int) -> str: - soc_version = get_ascend_soc_version() + def _select_moe_comm_method(self, num_tokens: int, + with_prefill: bool) -> MoECommType: + """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all + are designed for expert parallelism. + 2. If expert parallel is enabled, we need to consider the soc version and the + number of tokens. This is based on the observation that all-gather is more + efficient than all-to-all when running on A2. - if num_tokens <= self.mc2_tokens_capacity: - moe_comm_method = "mc2" + a. For A2, we choose from MC2 and all-gather. + + b. For A3, we choose from MC2 and all-to-all. + + In both cases, we use MC2 when the number of tokens is smaller than + a its capacity threshold. + + Args: + num_tokens (int): The number of tokens in the current batch. + + Raises: + ValueError: If the soc version is unsupported. + + Returns: + MoECommType: The selected MoE communication method. 
+ """ + soc_version = get_ascend_soc_version() + quant_type = getattr(self.vllm_config.model_config.hf_config, + 'moe_quantize', None) + model_type = self.vllm_config.model_config.hf_config.model_type + + if not self.parallel_config.enable_expert_parallel: + moe_comm_type = MoECommType.ALLGATHER elif soc_version in {AscendSocVersion.A2}: - moe_comm_method = "allgather" + if (num_tokens <= self.mc2_tokens_capacity + and self.parallel_config.world_size_across_dp >= 16): + moe_comm_type = MoECommType.MC2 + else: + # Currently, w4a8_dynamic does not support allgatherep + if quant_type == "w4a8_dynamic": + moe_comm_type = MoECommType.ALLTOALL + else: + moe_comm_type = MoECommType.ALLGATHER + elif soc_version in {AscendSocVersion.A3}: - moe_comm_method = "alltoall" + moe_comm_type = (MoECommType.MC2 + if num_tokens <= self.mc2_tokens_capacity else + MoECommType.ALLTOALL) else: raise ValueError(f"Unsupported soc_version: {soc_version}") + if moe_comm_type == MoECommType.ALLGATHER and with_prefill: + moe_comm_type = MoECommType.NAIVE_MULTICAST + + # PanguProMoE only supports allgather + if model_type == "PanguProMoE": + moe_comm_type = MoECommType.ALLGATHER + if is_global_first_rank(): logger.debug(f"num_tokens: {num_tokens}, " - f"moe_comm_method: {moe_comm_method}") - - return moe_comm_method + f"moe_comm_type: {moe_comm_type}") + return moe_comm_type @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, torch.Tensor]: + ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: with ProfileExecuteDuration().capture_async("prepare input"): self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: @@ -1654,16 +1921,28 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Return empty ModelRunnerOuptut if there's no work to do. 
return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward(scheduler_output) + + if self.dynamic_eplb: + self.eplb_updator.forward_before() + (attn_metadata, positions, num_scheduled_tokens_np, num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, logits_indices, spec_decode_metadata, input_ids, inputs_embeds, - intermediate_tensors) = (self._prepare_inputs( - scheduler_output, intermediate_tensors)) + intermediate_tensors, + max_query_len) = (self._prepare_inputs(scheduler_output, + intermediate_tensors)) - moe_comm_method = self._select_moe_comm_method(num_input_tokens) + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + moe_comm_type = self._select_moe_comm_method(num_input_tokens, + self.with_prefill) + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + scheduler_output.total_num_scheduled_tokens + == self.input_batch.num_reqs * max_query_len) batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=False) + uniform_decode=uniform_decode) aclgraph_runtime_mode, batch_descriptor = \ self.aclgraph_dispatcher.dispatch(batch_descriptor) @@ -1676,11 +1955,13 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_tokens_across_dp=num_tokens_across_dp, with_prefill=self.with_prefill, reserved_mc2_mask=self.reserved_mc2_mask, - moe_comm_method=moe_comm_method, + moe_comm_type=moe_comm_type, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, num_actual_tokens=scheduler_output. 
- total_num_scheduled_tokens): + total_num_scheduled_tokens, + prefetch_stream=self.prefetch_stream, + model_instance=self.model): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( @@ -1692,16 +1973,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): scheduler_output) aux_hidden_states = None - if self.use_aux_hidden_state_outputs: + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, aux_hidden_states = hidden_states - kv_connector_output = None - if finished_sending is not None or finished_recving is not None: - kv_connector_output = KVConnectorOutput( - finished_sending=finished_sending, - finished_recving=finished_recving) - else: - kv_connector_output = None + kv_connector_output = KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving) finished_sending = None finished_recving = None with ProfileExecuteDuration().capture_async("post process"): @@ -1723,21 +2000,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): logits = None else: if self.input_batch.pooling_params: - if vllm_version_is("0.10.1.1") or vllm_version_is( - "0.10.1"): - return self._pool_v010( - hidden_states, - scheduler_output.total_num_scheduled_tokens, - num_scheduled_tokens_np, finished_sending, - finished_recving, kv_connector_output) - else: - return self._pool( - hidden_states, - scheduler_output.total_num_scheduled_tokens, - num_scheduled_tokens_np, finished_sending, - finished_recving, kv_connector_output) + return self._pool( + hidden_states, + scheduler_output.total_num_scheduled_tokens, + num_scheduled_tokens_np, finished_sending, + finished_recving, kv_connector_output) sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states, None) + logits = self._compute_logits_wrapper(sample_hidden_states, + None) if broadcast_pp_output: model_output_broadcast_data = { "logits": logits.contiguous(), @@ -1790,6 +2060,8 @@ class 
NPUModelRunner(LoRAModelRunnerMixin): sampling_metadata, ) sampler_output.sampled_token_ids = output_token_ids + if self.need_accepted_tokens: + self._update_states_after_model_execute(output_token_ids) discard_sampled_tokens_req_indices: list[int] = [] # TODO(woosuk): The following loop can be slow since it iterates over @@ -1807,6 +2079,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): generator.set_offset(generator.get_offset() - 4) discard_sampled_tokens_req_indices.append(i) + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = \ + self.input_batch.req_id_to_index.copy() + # NOTE: NPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors @@ -1819,27 +2097,52 @@ class NPUModelRunner(LoRAModelRunnerMixin): scheduler_output, ) - # Get the valid generated tokens. + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = sampled_token_ids.tolist() + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() else: - # Includes spec decode tokens. 
- valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) + valid_sampled_token_ids = [] + invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() - # Cache the sampled tokens in the model runner, so that the schedulerAdd commentMore actions + # Cache the sampled tokens on the NPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = \ + sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } + # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. # NOTE(woosuk): As an exception, when using PP, the scheduler sends # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. 
- for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] * 1 if \ + req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: continue @@ -1876,27 +2179,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): extra_args = ({"kv_connector_output": kv_connector_output}) - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - logprobs=logprobs_lists, - spec_token_ids=self._draft_token_ids, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - **extra_args, - ) - else: - model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - **extra_args, - ) + model_runner_output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + **extra_args, + ) durations = ProfileExecuteDuration().pop_captured_sync() if durations: @@ -1907,8 +2198,17 @@ class NPUModelRunner(LoRAModelRunnerMixin): captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill" logger.info("Profile execute duration [%s]:%s", captured_name, " ".join(dr_str)) + if self.dynamic_eplb: + self.eplb_updator.forward_end() + if not self.use_async_scheduling: + return model_runner_output - return model_runner_output + return AsyncNPUModelRunnerOutput( + model_runner_output=model_runner_output, + sampled_token_ids=sampled_token_ids, + 
invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) def take_draft_token_ids(self) -> Optional[DraftTokenIds]: if self._draft_token_ids is None: @@ -1930,8 +2230,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): # For the case of no forward caused by receiving remote kv, # one round of dummy inference is necessary # to prevent hang over the collective calls. - if not finished_sending and not finished_recving: - return EMPTY_MODEL_RUNNER_OUTPUT output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.kv_connector_output = KVConnectorOutput( @@ -1965,12 +2263,54 @@ class NPUModelRunner(LoRAModelRunnerMixin): scheduler_output.finished_req_ids) return None, None - def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn): - if skip_attn: - attn_metadata = None - else: - # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata - attn_metadata = None + def _build_attention_metadata(self, create_mixed_batch, num_reqs, + num_tokens, max_query_len, force_attention): + attn_metadata: Optional[dict[str, Any]] = None + + if force_attention: + attn_metadata = {} + + if create_mixed_batch: + raise NotImplementedError( + "force_attention=True is not supported for mixed batches.") + else: + seq_lens = self.model_config.max_model_len + self.seq_lens_np[:num_reqs] = seq_lens + self.seq_lens_np[num_reqs:] = 0 + + num_computed_tokens_cpu = ( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + block_table_tensor = self.input_batch.block_table[ + kv_cache_group_id].get_device_tensor() + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + + 1], + seq_lens_cpu=self.seq_lens_cpu, + seq_lens=self.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + 
actual_seq_lengths_q=self.actual_seq_lengths_q, + block_table_tensor=block_table_tensor[:num_reqs], + slot_mapping=self.slot_mapping, + num_computed_tokens_cpu=num_computed_tokens_cpu, + max_query_len=max_query_len, + decode_token_per_req=self.decode_token_per_req, + ) + + for attn_group in self.attn_groups[kv_cache_group_id]: + if vllm_version_is("0.10.2"): + builder = attn_group.metadata_builder + else: + builder = attn_group.get_metadata_builder() + attn_metadata_i = builder.build_for_graph_capture( + common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + return attn_metadata def _generate_dummy_run_hidden_states(self, with_prefill, @@ -1981,12 +2321,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds) - if self.use_aux_hidden_state_outputs: + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, _ = hidden_states else: hidden_states = hidden_states - if self.use_spec_decode and isinstance(self.drafter, EagleProposer): - self.drafter.dummy_run(num_tokens) return hidden_states @torch.inference_mode() @@ -2001,18 +2339,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): ) -> torch.Tensor: # only support eager mode and piecewise graph now assert aclgraph_runtime_mode in { - CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE + CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } - if force_attention: - raise RuntimeError( - "Capturing attention in aclgraph is unexpected, because full graph is not supported now" - ) # Padding for DP (num_tokens, num_tokens_across_dp, with_prefill, _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False) - moe_comm_method = self._select_moe_comm_method(num_tokens) + moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.seperate_routine(). 
This means that we are using @@ -2030,7 +2364,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): max_query_len = self.uniform_decode_query_len if uniform_decode else \ num_tokens - max_num_reqs = self.scheduler_config.max_num_seqs # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. @@ -2059,12 +2392,19 @@ class NPUModelRunner(LoRAModelRunnerMixin): dtype=np.int32) # Force dummy run on prefill stage when this node is deemed as kv producer. - if self.is_kv_producer: + if self.is_kv_producer and not self.is_kv_consumer: with_prefill = True - attn_metadata = self._build_attention_metadata(with_prefill, - num_reqs, - skip_attn=True) + attn_metadata = self._build_attention_metadata( + with_prefill, + num_reqs, + num_tokens, + max_query_len, + force_attention, + ) + + if not self.in_profile_run and self.dynamic_eplb: + self.eplb_updator.forward_before() with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -2115,7 +2455,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): dtype=torch.int32) def dummy_compute_logits(hidden_states): - return self.model.compute_logits( + return self._compute_logits_wrapper( hidden_states[dummy_indices], None) with set_ascend_forward_context( @@ -2126,10 +2466,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): with_prefill=with_prefill, in_profile_run=self.in_profile_run, reserved_mc2_mask=self.reserved_mc2_mask, - moe_comm_method=moe_comm_method, + moe_comm_type=moe_comm_type, num_actual_tokens=0, aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor): + batch_descriptor=batch_descriptor, + prefetch_stream=self.prefetch_stream, + model_instance=self.model): hidden_states = self._generate_dummy_run_hidden_states( with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, @@ -2137,8 +2479,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): if 
need_dummy_logits: dummy_compute_logits(hidden_states) - if self.speculative_config and self.speculative_config.method == "deepseek_mtp": - assert isinstance(self.drafter, MtpProposer) + if self.drafter: self.drafter.dummy_run( num_tokens=num_tokens, with_prefill=with_prefill, @@ -2147,6 +2488,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_tokens_across_dp=num_tokens_across_dp) if need_dummy_logits: dummy_compute_logits(hidden_states) + if self.in_profile_run and self.dynamic_eplb: + self.model.clear_all_moe_loads() + if not self.in_profile_run and self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + self.eplb_updator.forward_end() return hidden_states @contextmanager @@ -2162,6 +2508,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): with self.set_in_profile_run(): hidden_states = self._dummy_run(self.max_num_tokens, with_prefill=True) + # MC2 will consume additional NPU memory. + # Therefore, we need to run the MC2 path once here to complete its initialization, + # allowing vLLM to correctly estimate the maximum memory required. 
+ if self._select_moe_comm_method( + self.mc2_tokens_capacity, + with_prefill=True) == MoECommType.MC2: + self._dummy_run(self.mc2_tokens_capacity, with_prefill=True) + output = None if get_pp_group().is_last_rank: if self.is_pooling_model: @@ -2179,13 +2533,18 @@ class NPUModelRunner(LoRAModelRunnerMixin): logit_indices = np.cumsum(num_scheduled_tokens) - 1 # TODO: need to rum a dummy sampler for generate task hidden_states = hidden_states[logit_indices] - output = self.model.compute_logits(hidden_states, None) + output = self._compute_logits_wrapper(hidden_states, None) NPUPlatform.synchronize() del hidden_states, output self.encoder_cache.clear() gc.collect() + def _compute_logits_wrapper(self, hidden_states, sampling_metadata): + if vllm_version_is("0.10.2"): + return self.model.compute_logits(hidden_states, sampling_metadata) + return self.model.compute_logits(hidden_states) + def _dummy_pooler_run_task( self, hidden_states: torch.Tensor, @@ -2200,8 +2559,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - hidden_states_list = list( - torch.split(hidden_states, num_scheduled_tokens_list)) req_num_tokens = num_tokens // num_reqs dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), @@ -2212,55 +2569,32 @@ class NPUModelRunner(LoRAModelRunnerMixin): dummy_pooling_params = PoolingParams(task=task) to_update = model.pooler.get_pooling_updates(task) to_update.apply(dummy_pooling_params) - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - dummy_prompt_lens = torch.tensor( - [h.shape[0] for h in hidden_states_list], - device=self.device, - ) - dummy_metadata = PoolingMetadata( - prompt_lens=dummy_prompt_lens, - prompt_token_ids=dummy_token_ids, - pooling_params=[dummy_pooling_params] * num_reqs, - ) - try: - return model.pooler(hidden_states=hidden_states_list, - pooling_metadata=dummy_metadata) - except RuntimeError as e: - if 'out of memory' in 
str(e): - raise RuntimeError( - "NPU out of memory occurred when warming up pooler " - f"({task=}) with {num_reqs} dummy requests. Please try " - "lowering `max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e - else: - raise e - else: - dummy_prompt_lens = torch.tensor( - num_scheduled_tokens_list, - device="cpu", - ) - dummy_metadata = PoolingMetadata( - prompt_lens=dummy_prompt_lens, - prompt_token_ids=dummy_token_ids, - pooling_params=[dummy_pooling_params] * num_reqs, - ) + dummy_prompt_lens = torch.tensor( + num_scheduled_tokens_list, + device="cpu", + ) + dummy_metadata = PoolingMetadata( + prompt_lens=dummy_prompt_lens, + prompt_token_ids=dummy_token_ids, + pooling_params=[dummy_pooling_params] * num_reqs, + ) - dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, - device=hidden_states.device) + dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, + device=hidden_states.device) - try: - return model.pooler(hidden_states=hidden_states, - pooling_metadata=dummy_metadata) - except RuntimeError as e: - if 'out of memory' in str(e): - raise RuntimeError( - "CUDA out of memory occurred when warming up pooler " - f"({task=}) with {num_reqs} dummy requests. Please try " - "lowering `max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e - else: - raise e + try: + return model.pooler(hidden_states=hidden_states, + pooling_metadata=dummy_metadata) + except RuntimeError as e: + if 'out of memory' in str(e): + raise RuntimeError( + "CUDA out of memory occurred when warming up pooler " + f"({task=}) with {num_reqs} dummy requests. 
Please try " + "lowering `max_num_seqs` or `gpu_memory_utilization` when " + "initializing the engine.") from e + else: + raise e @torch.inference_mode() def _dummy_pooler_run( @@ -2272,18 +2606,30 @@ class NPUModelRunner(LoRAModelRunnerMixin): for task in self.get_supported_pooling_tasks(): # Run a full batch with each task to ensure none of them OOMs output = self._dummy_pooler_run_task(hidden_states, task) - output_size[task] = output.get_data_nbytes() + if vllm_version_is("0.10.2"): + output_size[task] = output.get_data_nbytes() + else: + output_size[task] = sum(o.nbytes for o in output) del output # Allow GC max_task = max(output_size.items(), key=lambda x: x[1])[0] return self._dummy_pooler_run_task(hidden_states, max_task) + def eplb_warmup(self): + if self.dynamic_eplb and not self.is_eplb_warmuped: + self.is_eplb_warmuped = True + self.eplb_adaptor = VllmEplbAdaptor(model=self.model) + self.eplb_loader.set_adator(self.eplb_adaptor) + self.eplb_updator.set_adaptor(self.eplb_adaptor) + self.eplb_updator.warm_up_eplb() + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) - + if self.dynamic_eplb: + model_register(self.model, self.model_config) if is_310p(): from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -2296,22 +2642,33 @@ class NPUModelRunner(LoRAModelRunnerMixin): module.weight.data) if self.drafter: logger.info("Loading drafter model...") - if isinstance(self.drafter, EagleProposer): - if self.use_aux_hidden_state_outputs: - self.drafter.load_model(self.model) - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) - else: - self.drafter.load_model() + self.drafter.load_model(self.model) + if self.drafter.name == SpecDcodeType.EAGLE3: + self.model.set_aux_hidden_state_layers( + 
self.model.get_eagle3_aux_hidden_state_layers()) + if self.lora_config: - self.model = self.load_lora_model(self.model, - self.model_config, - self.scheduler_config, - self.lora_config, - self.device) + if vllm_version_is("0.10.2"): + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) + else: + self.model = self.load_lora_model(self.model, + self.vllm_config, + self.device) logger.info("Loading model weights took %.4f GB", m.consumed_memory / float(2**30)) + # wrap the model with full graph wrapper if needed. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.update_stream = torch.npu.Stream() + set_graph_params(self.compilation_config.cudagraph_capture_sizes) + self.model = ACLGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + def _convert_torch_format(self, tensor): tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) return tensor @@ -2323,31 +2680,46 @@ class NPUModelRunner(LoRAModelRunnerMixin): kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - kv_caches: Dict[str, torch.Tensor] = {} + self.initialize_attn_backend(kv_cache_config) + self.use_hybrid_blocks = (len(self.attn_groups) > 1) + # NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`. 
+ if vllm_version_is("0.10.2"): + self.need_accepted_tokens = any([ + isinstance( + self.kv_cache_config.kv_cache_groups[0].kv_cache_spec, + MambaSpec) for attn_group in self.attn_groups + ]) + else: + self.need_accepted_tokens = any([ + isinstance(attn_group[0].kv_cache_spec, MambaSpec) + for attn_group in self.attn_groups + ]) - def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: - data_ptr = tensor.data_ptr() - aligned_addr = (data_ptr + alignment - 1) // alignment * alignment - offset = (aligned_addr - data_ptr) // tensor.element_size() - return tensor[int(offset):] + self.may_reinitialize_input_batch(kv_cache_config) - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.model_config.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - block_sizes=[self.block_size], - is_spec_decode=bool(self.vllm_config.speculative_config), - logitsprocs=build_logitsprocs( - self.vllm_config, self.device, self.pin_memory, - self.is_pooling_model, - self.vllm_config.model_config.logits_processors), - is_pooling_model=self.is_pooling_model, - ) + if self.ascend_config.is_deepseek_sfa: + kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa( + kv_cache_config) + elif self.model_config.is_deepseek_mla: + kv_caches = self.initialize_kv_cache_tensors_deepseek_mla( + kv_cache_config) + else: + kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + + def _align_memory(self, tensor: torch.Tensor, + alignment: int) -> torch.Tensor: + data_ptr = tensor.data_ptr() + aligned_addr = (data_ptr + alignment - 1) // alignment * alignment + offset = (aligned_addr - data_ptr) // tensor.element_size() + return tensor[int(offset):] + + def initialize_kv_cache_tensors_deepseek_sfa( + self, kv_cache_config: KVCacheConfig) -> dict[str, 
torch.Tensor]: kv_cache_sizes = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: assert len(kv_cache_tensor.shared_by) == 1, ( @@ -2355,115 +2727,610 @@ class NPUModelRunner(LoRAModelRunnerMixin): "NPU.") kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size - for kv_cache_group in kv_cache_config.kv_cache_groups: - kv_cache_spec = kv_cache_group.kv_cache_spec - for layer_name in kv_cache_group.layer_names: + kv_caches: Dict[str, torch.Tensor] = {} + for group in self._kv_cache_spec_attn_group_iterator_dispatcher(): + if vllm_version_is("0.10.2"): + kv_cache_spec, group = group + else: + kv_cache_spec = group.kv_cache_spec + attn_backend = group.backend + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue tensor_size = kv_cache_sizes[layer_name] - assert tensor_size % kv_cache_spec.page_size_bytes == 0 num_blocks = tensor_size // kv_cache_spec.page_size_bytes + if self.vllm_config.additional_config.get( + "kv_cache_dtype", None) == 'int8': + kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + elif hasattr( + attn_backend, "get_supported_block_size" + ) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa: + block_size = attn_backend.get_supported_block_size()[0] + block_size_chunk = kv_cache_spec.block_size // block_size + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks * block_size_chunk, block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + else: + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype - # `num_blocks` is the number of blocks the model runner can use. - # `kv_cache_config.num_blocks` is the number of blocks that - # KVCacheManager may allocate. 
- # Since different GPUs may have different number of layers and - # different memory capacities, `num_blocks` can be different on - # different GPUs, and `kv_cache_config.num_blocks` is set to - # the min of all `num_blocks`. Verify it here. - assert num_blocks >= kv_cache_config.num_blocks alignment = 2 * 1024 * 1024 + num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape + rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + nope_dim = head_size - rope_dim + nope_cache_shape = (num_blocks, block_size, num_kv_heads, + nope_dim) + rope_cache_shape = (num_blocks, block_size, num_kv_heads, + rope_dim) + #### k cache + # TODO(zzzzwwjj): wait transformers add these params + k_cache_shape = (num_blocks, block_size, 1, 128) + if self.vllm_config.kv_transfer_config is None: + # For no disaggregate pd scenario, allocate kv cache in normal way + rope_cache = torch.zeros(rope_cache_shape, + dtype=dtype, + device=self.device) + nope_cache = torch.zeros(nope_cache_shape, + dtype=dtype, + device=self.device) + rope_cache = self._convert_torch_format(rope_cache) + nope_cache = self._convert_torch_format(nope_cache) + + #### k cache + k_cache = torch.zeros(k_cache_shape, + dtype=dtype, + device=self.device) + k_cache = self._convert_torch_format(k_cache) + else: + + # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory + # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but + # we found there are also some exceptions during test, so we manual align those memory here, this part + # of code may consume 2M * 2 * elem_size memory every layer. 
+ nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim + nope_allocate_shape_alignment = nope_allocate_shape + alignment + rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim + rope_allocate_shape_alignment = rope_allocate_shape + alignment + + nope_cache = torch.zeros(nope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + rope_cache = torch.zeros(rope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + #### k cache + # TODO(zzzzwwjj): wait transformers add these params + k_allocate_shape = num_blocks * block_size * 1 * 128 + k_allocate_shape_alignment = k_allocate_shape + alignment + k_cache = torch.zeros(k_allocate_shape_alignment, + dtype=dtype, + device=self.device) + + nope_cache = self._align_memory( + nope_cache, + alignment)[:nope_allocate_shape].view(nope_cache_shape) + rope_cache = self._align_memory( + rope_cache, + alignment)[:rope_allocate_shape].view(rope_cache_shape) + k_cache = self._align_memory( + k_cache, + alignment)[:k_allocate_shape].view(k_cache_shape) + + kv_caches[layer_name] = (nope_cache, rope_cache, k_cache) + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + + return kv_caches + + def initialize_kv_cache_tensors_deepseek_mla( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + kv_cache_sizes = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + assert len(kv_cache_tensor.shared_by) == 1, ( + "KV cache tensor shared by multiple layers is not supported in " + "NPU.") + kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size + + kv_caches: Dict[str, torch.Tensor] = {} + for group in self._kv_cache_spec_attn_group_iterator_dispatcher(): + if vllm_version_is("0.10.2"): + kv_cache_spec, group = group + else: + kv_cache_spec = group.kv_cache_spec + attn_backend = group.backend + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + tensor_size = 
kv_cache_sizes[layer_name] + num_blocks = tensor_size // kv_cache_spec.page_size_bytes + if self.vllm_config.additional_config.get( + "kv_cache_dtype", None) == 'int8': + kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + elif hasattr(attn_backend, "get_supported_block_size" + ) and not self.model_config.is_deepseek_mla: + block_size = attn_backend.get_supported_block_size()[0] + block_size_chunk = kv_cache_spec.block_size // block_size + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks * block_size_chunk, block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + else: + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + + alignment = 2 * 1024 * 1024 + num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape + rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + nope_dim = head_size - rope_dim + nope_cache_shape = (num_blocks, block_size, num_kv_heads, + nope_dim) + rope_cache_shape = (num_blocks, block_size, num_kv_heads, + rope_dim) + if self.vllm_config.kv_transfer_config is None: + # For no disaggregate pd scenario, allocate kv cache in normal way + rope_cache = torch.zeros(rope_cache_shape, + dtype=dtype, + device=self.device) + nope_cache = torch.zeros(nope_cache_shape, + dtype=dtype, + device=self.device) + rope_cache = self._convert_torch_format(rope_cache) + nope_cache = self._convert_torch_format(nope_cache) + else: + + # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory + # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but + # we found there are also some exceptions during test, so we manual align those memory here, this part + # of code may consume 2M * 2 * elem_size memory every layer. 
+ nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim + nope_allocate_shape_alignment = nope_allocate_shape + alignment + rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim + rope_allocate_shape_alignment = rope_allocate_shape + alignment + + nope_cache = torch.zeros(nope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + rope_cache = torch.zeros(rope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + nope_cache = self._align_memory( + nope_cache, + alignment)[:nope_allocate_shape].view(nope_cache_shape) + rope_cache = self._align_memory( + rope_cache, + alignment)[:rope_allocate_shape].view(rope_cache_shape) + kv_caches[layer_name] = (nope_cache, rope_cache) + + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + + return kv_caches + + def initialize_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + """ + Initialize the memory buffer for KV cache. + + Args: + kv_cache_config: The KV cache config + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. 
+ """ + # init kv cache tensors + kv_cache_raw_tensors: dict[str, Union[torch.Tensor, + Optional[torch.Tensor]]] = {} + # llmdatadist need the addr of cache tensor be aligned with 2M + alignment = 2 * 1024 * 1024 + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + # TODO: REFACTOR ME to sharing hybrid cache + for idx in range(len(kv_cache_tensor.shared_by)): + layer_name = kv_cache_tensor.shared_by[idx] + if "linear_attn" in layer_name: + # for mamba linear attention + for layer_name_inner in kv_cache_tensor.shared_by: + if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \ + layer_name_inner in kv_cache_raw_tensors.keys(): + continue + if self.vllm_config.kv_transfer_config is None: + tensor = torch.zeros(kv_cache_tensor.size, + dtype=torch.int8, + device=self.device) + else: + cache_size_aligned = kv_cache_tensor.size + alignment + tensor = torch.zeros(cache_size_aligned, + dtype=torch.int8, + device=self.device) + tensor = self._align_memory( + tensor, alignment)[:kv_cache_tensor.size] + kv_cache_raw_tensors[layer_name_inner] = tensor + elif "attn" in layer_name: + # for other attentions, e.g., self_attn, sliding window attn + if self.vllm_config.kv_transfer_config is None: + k_tensor = torch.zeros(kv_cache_tensor.size // 2, + dtype=torch.int8, + device=self.device) + v_tensor = torch.zeros(kv_cache_tensor.size // 2, + dtype=torch.int8, + device=self.device) + else: + cache_size = kv_cache_tensor.size // 2 + cache_size_aligned = kv_cache_tensor.size // 2 + alignment + k_tensor = torch.zeros(cache_size_aligned, + dtype=torch.int8, + device=self.device) + v_tensor = torch.zeros(cache_size_aligned, + dtype=torch.int8, + device=self.device) + k_tensor = self._align_memory(k_tensor, + alignment)[:cache_size] + v_tensor = self._align_memory(v_tensor, + alignment)[:cache_size] + kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor) + + layer_names = set() + for group in kv_cache_config.kv_cache_groups: + for layer_name in 
group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + layer_names.add(layer_name) + assert layer_names == set(kv_cache_raw_tensors.keys( + )), "Some layers are not correctly initialized" + + kv_caches: Dict[str, torch.Tensor] = {} + for group in self._kv_cache_spec_attn_group_iterator_dispatcher(): + if vllm_version_is("0.10.2"): + kv_cache_spec, group = group + else: + kv_cache_spec = group.kv_cache_spec + attn_backend = group.backend + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may # encounter OOM issue if isinstance(kv_cache_spec, FullAttentionSpec): + raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[ # type: ignore + layer_name] + assert raw_k_tensor is not None + assert raw_v_tensor is not None + assert (raw_k_tensor.numel() + raw_v_tensor.numel() + ) % kv_cache_spec.page_size_bytes == 0 + num_blocks = (raw_k_tensor.numel() + raw_v_tensor.numel() + ) // kv_cache_spec.page_size_bytes + + # `num_blocks` is the number of blocks the model runner can use. + # `kv_cache_config.num_blocks` is the number of blocks that + # KVCacheManager may allocate. + # Since different GPUs may have different number of layers and + # different memory capacities, `num_blocks` can be different on + # different GPUs, and `kv_cache_config.num_blocks` is set to + # the min of all `num_blocks`. Verify it here. 
+ assert num_blocks >= kv_cache_config.num_blocks + if self.vllm_config.additional_config.get( "kv_cache_dtype", None) == 'int8': - kv_cache_shape = self.attn_backend.get_bsh_kv_cache_shape( + kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + elif hasattr(attn_backend, "get_supported_block_size" + ) and self.use_hybrid_blocks: + block_size = attn_backend.get_supported_block_size()[0] + + block_size_chunk = kv_cache_spec.block_size // block_size + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks * block_size_chunk, block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) else: kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype - if self.model_config.is_deepseek_mla: - num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape - rope_dim = self.model_config.hf_text_config.qk_rope_head_dim - nope_dim = head_size - rope_dim - nope_cache_shape = (num_blocks, block_size, - num_kv_heads, nope_dim) - rope_cache_shape = (num_blocks, block_size, - num_kv_heads, rope_dim) - if self.vllm_config.kv_transfer_config is None: - # For no disaggregate pd scenario, allocate kv cache in normal way - rope_cache = torch.zeros(rope_cache_shape, - dtype=dtype, - device=self.device) - nope_cache = torch.zeros(nope_cache_shape, - dtype=dtype, - device=self.device) - rope_cache = self._convert_torch_format(rope_cache) - nope_cache = self._convert_torch_format(nope_cache) - else: + k_cache = raw_k_tensor.view(dtype).view(kv_cache_shape[1:]) + k_cache = self._convert_torch_format(k_cache) + v_cache = raw_v_tensor.view(dtype).view(kv_cache_shape[1:]) + v_cache = self._convert_torch_format(v_cache) + kv_caches[layer_name] = (k_cache, v_cache) + elif isinstance(kv_cache_spec, MambaSpec): + raw_tensor = kv_cache_raw_tensors[layer_name] + assert 
raw_tensor is not None + assert raw_tensor.numel( + ) % kv_cache_spec.page_size_bytes == 0 + num_blocks = raw_tensor.numel( + ) // kv_cache_spec.page_size_bytes - # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory - # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but - # we found there are also some exceptions during test, so we manual align those memory here, this part - # of code may consume 2M * 2 * elem_size memory every layer. - nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim - nope_allocate_shape_alignment = nope_allocate_shape + alignment - rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim - rope_allocate_shape_alignment = rope_allocate_shape + alignment + # `num_blocks` is the number of blocks the model runner can use. + # `kv_cache_config.num_blocks` is the number of blocks that + # KVCacheManager may allocate. + # Since different GPUs may have different number of layers and + # different memory capacities, `num_blocks` can be different on + # different GPUs, and `kv_cache_config.num_blocks` is set to + # the min of all `num_blocks`. Verify it here. 
+ assert num_blocks >= kv_cache_config.num_blocks - nope_cache = torch.zeros( - nope_allocate_shape_alignment, - dtype=dtype, - device=self.device) - rope_cache = torch.zeros( - rope_allocate_shape_alignment, - dtype=dtype, - device=self.device) - nope_cache = align_memory( - nope_cache, - alignment)[:nope_allocate_shape].view( - nope_cache_shape) - rope_cache = align_memory( - rope_cache, - alignment)[:rope_allocate_shape].view( - rope_cache_shape) - kv_caches[layer_name] = (nope_cache, rope_cache) - else: - num_caches = kv_cache_shape[0] - kv_cache_list = [] - for i in range(num_caches): - cache_shape = kv_cache_shape[1:] - if self.vllm_config.kv_transfer_config is None: - kv_cache = torch.zeros(cache_shape, - dtype=dtype, - device=self.device) - kv_cache = self._convert_torch_format(kv_cache) - else: - cache_size = math.prod(cache_shape) - cache_size_aligned = cache_size + alignment - kv_cache = torch.zeros(cache_size_aligned, - dtype=dtype, - device=self.device) - kv_cache = align_memory( - kv_cache, - alignment)[:cache_size].view(cache_shape) - kv_cache_list.append(kv_cache) - kv_caches[layer_name] = tuple(kv_cache_list) + state_tensors = [] + storage_offset_bytes = 0 + for (shape, dtype) in zip(kv_cache_spec.shapes, + kv_cache_spec.dtypes): + dtype_size = get_dtype_size(dtype) + num_element_per_page = ( + kv_cache_spec.page_size_bytes // dtype_size) + target_shape = (num_blocks, *shape) + stride = torch.empty(target_shape).stride() + target_stride = (num_element_per_page, *stride[1:]) + assert storage_offset_bytes % dtype_size == 0 + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset_bytes // dtype_size, + ) + state_tensors.append(tensor) + storage_offset_bytes += stride[0] * dtype_size + kv_caches[layer_name] = state_tensors else: - # TODO: add new branches when introducing more types of - # KV cache specs. 
raise ValueError("Unknown KV cache spec type.") bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches) - if has_kv_transfer_group(): - get_kv_transfer_group().register_kv_caches(kv_caches) + return kv_caches + + def may_reinitialize_input_batch(self, + kv_cache_config: KVCacheConfig) -> None: + """ + Re-initialize the input batch if the block sizes are different from + `[self.cache_config.block_size]`. This usually happens when there + are multiple KV cache groups. + + Args: + kv_cache_config: The KV cache configuration. + """ + block_sizes = [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in kv_cache_config.kv_cache_groups + ] + + # Generate kernel_block_sizes that matches each block_size + # For attention backends that support virtual block splitting, + # use the supported block sizes from the backend + # For other backends (like Mamba), use [0] (no splitting) + kernel_block_sizes = [] + for kv_cache_group_id, kv_cache_group in enumerate( + kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): + # This is an attention backend that supports virtual + # block splitting. Get the supported block sizes from + # the backend. + try: + attn_groups = self.attn_groups[kv_cache_group_id] + except IndexError: + attn_groups = None + if attn_groups and self.use_hybrid_blocks: + # Use the backend's supported block size list + backend = attn_groups[0].backend + supported_sizes = backend.get_supported_block_size() + # If no specific sizes supported, use cache config + # block_size + kernel_block_size_list = (supported_sizes + if supported_sizes else + [self.cache_config.block_size]) + else: + # Fallback to cache config block_size if no backend found + kernel_block_size_list = [self.cache_config.block_size] + kernel_block_sizes.append(kernel_block_size_list) + else: + # This is likely Mamba or other non-attention cache, + # no splitting. 
+ # NOTE: set kernel_block_sizes to 0 to disable slotmapping computation + # of mamba block. In this case, BlockTable.block_size will never equal + # to kernel_block_sizes[0] + kernel_block_sizes.append([0]) + if kernel_block_sizes != [self.cache_config.block_size]: + assert self.cache_config.cpu_offload_gb == 0, ( + "Cannot re-initialize the input batch when CPU weight " + "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 + "for more details.") + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.model_config.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=block_sizes, + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=self.input_batch.logitsprocs, + is_pooling_model=self.is_pooling_model, + num_speculative_tokens=( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config else 0), + kernel_block_sizes=kernel_block_sizes, + ) + + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize the attention backends and attention metadata builders. + """ + assert len(self.attn_groups) == 0, \ + "Attention backends are already initialized" + + class AttentionGroupKey(NamedTuple): + attn_backend: type[AttentionBackend] + kv_cache_spec: KVCacheSpec + + def get_attn_backends_for_group( + kv_cache_group_spec: KVCacheGroupSpec, + ) -> dict[AttentionGroupKey, list[str]]: + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, + kv_cache_group_spec.layer_names) + attn_backends = {} + attn_backend_layers = defaultdict(list) + # Dedupe based on full class name; this is a bit safer than + # using the class itself as the key because when we create dynamic + # attention backend subclasses (e.g. 
ChunkedLocalAttention) unless + # they are cached correctly, there will be different objects per + # layer. + for layer_name in kv_cache_group_spec.layer_names: + attn_backend = layers[layer_name].get_attn_backend() + full_cls_name = attn_backend.full_cls_name() + layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ + layer_name] + key = (full_cls_name, layer_kv_cache_spec) + attn_backends[key] = AttentionGroupKey(attn_backend, + layer_kv_cache_spec) + attn_backend_layers[key].append(layer_name) + return { + attn_backends[k]: v + for k, v in attn_backend_layers.items() + } + + def get_attn_backends_for_layers( + layer_names: list[str] + ) -> dict[type[AttentionBackend], list[str]]: + """Get attention_backend for all attention layers + TODO: Only used in v0.10.2, drop me when 0.10.2 is dropped + """ + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) + attn_backends = {} + attn_backend_layers = defaultdict(list) + # Dedupe based on full class name; this is a bit safer than + # using the class itself as the key because when we create dynamic + # attention backend subclasses (e.g. ChunkedLocalAttention) unless + # they are cached correctly, there will be different objects per + # layer. 
+ for layer_name in layer_names: + attn_backend = layers[layer_name].get_attn_backend() + key = attn_backend.full_cls_name() + attn_backends[key] = attn_backend + attn_backend_layers[key].append(layer_name) + return { + attn_backends[k]: v + for k, v in attn_backend_layers.items() + } + + def create_attn_groups_v0102( + attn_backends_map: dict[AttentionBackend, list[str]], + kv_cache_spec: KVCacheSpec, + ) -> list[AttentionGroup]: + attn_groups: list[AttentionGroup] = [] + for attn_backend, layer_names in attn_backends_map.items(): + attn_metadata_builder_i = attn_backend.get_builder_cls()( + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, + ) + attn_group = AttentionGroup(attn_backend, + attn_metadata_builder_i, + layer_names) + attn_groups.append(attn_group) + return attn_groups + + def create_attn_groups( + attn_backends_map: dict[AttentionBackend, list[str]], + ) -> list[AttentionGroup]: + attn_groups: list[AttentionGroup] = [] + for (attn_backend, + kv_cache_spec), layer_names in attn_backends_map.items(): + attn_metadata_builders = [] + attn_metadata_builders.append(attn_backend.get_builder_cls()( + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, + )) + attn_group = AttentionGroup(attn_backend, + attn_metadata_builders, + layer_names, kv_cache_spec) + attn_groups.append(attn_group) + return attn_groups + + if vllm_version_is("0.10.2"): + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + attn_backends = get_attn_backends_for_layers( + kv_cache_group_spec.layer_names) + self.attn_groups.append( + create_attn_groups_v0102(attn_backends, kv_cache_spec)) + else: + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + attn_backends = get_attn_backends_for_group( # type: ignore + kv_cache_group_spec) + self.attn_groups.append(create_attn_groups(attn_backends)) + + # Calculate reorder batch threshold (if needed) + self.calculate_reorder_batch_threshold() + + def 
_attn_group_iterator(self) -> Iterator[AttentionGroup]: + return itertools.chain.from_iterable(self.attn_groups) + + def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: + if not self.kv_cache_config.kv_cache_groups: + return + for attn_groups in self.attn_groups: + yield from attn_groups + + def _kv_cache_spec_attn_group_iterator_v0102( + self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]: + if not self.kv_cache_config.kv_cache_groups: + return + for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups): + for attn_group in attn_groups: + yield self.kv_cache_config.kv_cache_groups[ + kv_cache_spec_id].kv_cache_spec, attn_group + + def _kv_cache_spec_attn_group_iterator_dispatcher(self): + if vllm_version_is("0.10.2"): + return self._kv_cache_spec_attn_group_iterator_v0102() + else: + return self._kv_cache_spec_attn_group_iterator() + + def calculate_reorder_batch_threshold(self) -> None: + """ + Check that if any backends reorder batches; that the reordering + is compatible (e.g., decode threshold is the same) + """ + for group in self._attn_group_iterator(): + if vllm_version_is("0.10.2"): + attn_metadata_builder_i = group.metadata_builder + else: + attn_metadata_builder_i = group.get_metadata_builder() + if hasattr(attn_metadata_builder_i, "reorder_batch_threshold"): + # check that if any backends reorder batches; that the reordering + # is compatible (e.g., decode threshold is the same) + reorder_batch_threshold_i = ( + attn_metadata_builder_i.reorder_batch_threshold) + if reorder_batch_threshold_i is not None: + if self.reorder_batch_threshold is not None: + if reorder_batch_threshold_i != \ + self.reorder_batch_threshold: + raise ValueError( + f"Attention backend reorders decodes with " + f"threshold {reorder_batch_threshold_i} but other " + f"backend uses threshold " + f"{self.reorder_batch_threshold}") + else: + self.reorder_batch_threshold = reorder_batch_threshold_i def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ 
@@ -2474,23 +3341,37 @@ class NPUModelRunner(LoRAModelRunnerMixin): format. Layers that do not need KV cache are not included. """ - forward_ctx = self.compilation_config.static_forward_context + block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla + use_sfa = self.ascend_config.use_sfa kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): - if isinstance(attn_module, FusedMoE): + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if (kv_tgt_layer := + attn_module.kv_sharing_target_layer_name) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue + if isinstance(attn_module, AscendMultiHeadLatentAttention): continue - # TODO: Support other attention modules, e.g., sliding window, - # cross-attention - assert isinstance(attn_module, Attention) + # TODO: Support other attention modules, e.g., cross-attention + # TODO(lucas): move the attention specs into the model layers like + # the attention backends if attn_module.attn_type == AttentionType.DECODER: kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=self.block_size, + block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=use_mla) + use_mla=use_mla, + use_sfa=use_sfa) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. 
@@ -2501,12 +3382,110 @@ class NPUModelRunner(LoRAModelRunnerMixin): raise ValueError( f"Unknown attention type: {attn_module.attn_type}") + mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) + if len(mamba_layers) > 0: + if (self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"]): + raise NotImplementedError( + "Mamba with speculative decoding is not supported yet.") + if self.vllm_config.cache_config.enable_prefix_caching: + raise NotImplementedError( + "Prefix caching is not supported for Mamba yet.") + max_model_len = self.vllm_config.model_config.max_model_len + + page_size_padded = ( + self.vllm_config.cache_config.mamba_page_size_padded) + + # Set block_size to max_model_len, so that mamba model will always + # have only one block in the KV cache. + for layer_name, mamba_module in mamba_layers.items(): + kv_cache_spec[layer_name] = MambaSpec( + shapes=mamba_module.get_state_shape(), + dtypes=mamba_module.get_state_dtype(), + block_size=max_model_len, + page_size_padded=page_size_padded, + mamba_type=mamba_module.mamba_type, + num_speculative_blocks=( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0), + ) + return kv_cache_spec def initialize_aclgraph_capture(self) -> None: - # TODO: Add check of AttentionCGSupport and cudagraph_mode.decode_mode when full graph is supported - # Trigger aclgraph dispatching keys initialization here (after - # initializing attn backends). 
+ min_ag_support = AttentionCGSupport.ALWAYS + min_ag_builder_name = None + + for attn_group in self._attn_group_iterator(): + if vllm_version_is("0.10.2"): + builder = attn_group.metadata_builder + else: + builder = attn_group.get_metadata_builder() + if builder.aclgraph_support.value < min_ag_support.value: + min_ag_support = builder.aclgraph_support + min_ag_builder_name = builder.__class__.__name__ + + # This is an imitation of compilation_config.splitting_ops_contain_attention() + splitting_ops_contain_attention = ( + self.compilation_config.splitting_ops is not None + and all(op in self.compilation_config.splitting_ops for op in [ + "vllm.unified_ascend_attention_with_output", + "vllm.mla_forward", + ])) + + # Flexible resolve the aclgraph mode + aclgraph_mode = self.compilation_config.cudagraph_mode + # check graph for mixed batch is supported + if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \ + and min_ag_support != AttentionCGSupport.ALWAYS: + msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported " + f"with {min_ag_builder_name} backend (support: " + f"{min_ag_support})") + if min_ag_support == AttentionCGSupport.NEVER: + # if not supported any full graphs, just raise it. 
+ msg += "; please try cudagraph_mode=PIECEWISE, and "\ + "make sure compilation level is piecewise" + raise ValueError(msg) + + # attempt to resolve the full graph related mode + if splitting_ops_contain_attention: + msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" + aclgraph_mode = self.compilation_config.cudagraph_mode = ( + CUDAGraphMode.FULL_AND_PIECEWISE) + else: + msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" + aclgraph_mode = self.compilation_config.cudagraph_mode = ( + CUDAGraphMode.FULL_DECODE_ONLY) + logger.warning(msg) + + # check that if spec-decode + decode full-graphs is supported + if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.uniform_decode_query_len > 1 and min_ag_support.value + < AttentionCGSupport.UNIFORM_BATCH.value): + msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported" + f" with spec-decode for attention backend " + f"{min_ag_builder_name} (support: {min_ag_support})") + if splitting_ops_contain_attention: + msg += "; setting cudagraph_mode=PIECEWISE" + aclgraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE + else: + msg += "; setting cudagraph_mode=NONE" + aclgraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.NONE + logger.warning(msg) + + # double check that we can support full graph if they are requested + # even after automatic downgrades + if aclgraph_mode.has_full_cudagraphs() \ + and min_ag_support == AttentionCGSupport.NEVER: + raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not " + f"supported with {min_ag_builder_name} backend (" + f"support:{min_ag_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation level is piecewise") + self.aclgraph_dispatcher.initialize_cudagraph_keys( self.compilation_config.cudagraph_mode, self.uniform_decode_query_len) @@ -2515,10 +3494,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): aclgraph_runtime_mode: CUDAGraphMode, uniform_decode: bool): assert aclgraph_runtime_mode != 
CUDAGraphMode.NONE and \ - aclgraph_runtime_mode in [CUDAGraphMode.PIECEWISE] + aclgraph_runtime_mode in [CUDAGraphMode.FULL, + CUDAGraphMode.PIECEWISE] # Only rank 0 should print progress bar during capture if is_global_first_rank(): + logger.info( + "Starting to capture ACL graphs for cases: %s, " + "mode: %s, uniform_decode: %s", compilation_cases, + aclgraph_runtime_mode.name, uniform_decode) compilation_cases = tqdm( compilation_cases, disable=not self.load_config.use_tqdm_on_load, @@ -2540,6 +3524,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): uniform_decode=uniform_decode) self._dummy_run(num_tokens, aclgraph_runtime_mode=aclgraph_runtime_mode, + force_attention=force_attention, uniform_decode=uniform_decode) def _capture_model(self): @@ -2561,10 +3546,38 @@ class NPUModelRunner(LoRAModelRunnerMixin): aclgraph_runtime_mode = aclgraph_mode.mixed_mode() compilation_cases = list(reversed(self.aclgraph_batch_sizes)) + + try: + self._capture_aclgraphs( + compilation_cases, + aclgraph_runtime_mode=aclgraph_runtime_mode, + uniform_decode=False) + except Exception as e: + logger.error( + f"ACLgraph sizes capture fail: {type(e).__name__}:\n" + "ACLgraph has insufficient available streams to capture the configured number of sizes. " + "Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n" + "Recommended solutions:\n" + "1. Manually configure the compilation_config parameter " + "with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n" + "2. 
Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n" + f"{str(e)}") + raise + + if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \ + aclgraph_mode.separate_routine(): + max_num_tokens = self.scheduler_config.max_num_seqs * \ + self.uniform_decode_query_len + decode_cudagraph_batch_sizes = [ + x for x in self.aclgraph_batch_sizes if x <= max_num_tokens + and x >= self.uniform_decode_query_len + ] + compilation_cases_decode = list( + reversed(decode_cudagraph_batch_sizes)) self._capture_aclgraphs( - compilation_cases, - aclgraph_runtime_mode=aclgraph_runtime_mode, - uniform_decode=False) + compilation_cases=compilation_cases_decode, + aclgraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True) # Disable aclgraph capturing globally, so any unexpected aclgraph # capturing will be detected and raise an error after here. @@ -2590,193 +3603,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, npu_graph_size / (1 << 30)) - def _generate_ngram_token_ids( - self, - sampled_token_ids: list[list[int]], - ) -> list[list[int]]: - # TODO(woosuk): Optimize. - draft_token_ids: list[list[int]] = [] - for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) - if not num_sampled_ids: - # Skip speculative decoding. - draft_token_ids.append([]) - continue - - # Skip requests that require top-p, top-k, etc. - req_id = self.input_batch.req_ids[i] - if req_id in self.input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) - continue - - # Add sampled_token_ids to token_ids_cpu. 
- start_idx = self.input_batch.num_tokens_no_spec[i] - end_idx = start_idx + num_sampled_ids - self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids - assert isinstance(self.drafter, NgramProposer) - drafter_output = self.drafter.propose( - self.input_batch.token_ids_cpu[i, :end_idx]) - if drafter_output is None or len(drafter_output) == 0: - draft_token_ids.append([]) - else: - draft_token_ids.append(drafter_output.tolist()) - return draft_token_ids - - def _generate_eagle3_token_ids(self, - valid_sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, - scheduler_output: "SchedulerOutput", - spec_decode_metadata: SpecDecodeMetadata, - positions: torch.Tensor, - num_scheduled_tokens: int, - hidden_states: torch.Tensor, - aux_hidden_states: torch.Tensor = None): - assert isinstance(self.drafter, EagleProposer) - attn_metadata = self.get_eagle_atten_dict(scheduler_output) - next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = self.input_batch.req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) - eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] - if spec_decode_metadata is None: - # input_ids can be None for multimodal models. 
- target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) - else: - target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = eagle_attn_metadata.slot_mapping - cu_num_tokens = eagle_attn_metadata.query_start_loc - else: - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor( - num_rejected_tokens, - dtype=torch.int32, - device=self.device, - ) - num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) - cu_num_tokens, token_indices = self.drafter.prepare_inputs( - eagle_attn_metadata.query_start_loc, num_rejected_tokens, - num_tokens) - target_token_ids = self.input_ids[token_indices] - target_positions = positions[token_indices] - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) - else: - target_hidden_states = hidden_states[token_indices] - target_slot_mapping = eagle_attn_metadata.slot_mapping[ - token_indices] - - draft_token_ids = self.drafter.propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, - next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=eagle_attn_metadata.block_tables, - sampling_metadata=sampling_metadata, - ) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids - - def _generate_mtp_token_ids( - self, - valid_sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, - scheduler_output: "SchedulerOutput", - spec_decode_metadata: SpecDecodeMetadata, - positions: torch.Tensor, - num_scheduled_tokens: int, 
- hidden_states: torch.Tensor, - attn_metadata: Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata, - AscendMLATorchairMetadata], - ): - assert isinstance(self.drafter, MtpProposer) - next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = self.input_batch.req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) - accepted_token_indices = None - if spec_decode_metadata is None: - # input_ids can be None for multimodal models. - target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = attn_metadata.slot_mapping - cu_num_tokens = attn_metadata.query_start_loc - else: - # TODO(woosuk): Refactor this. 
- num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor( - num_rejected_tokens, - dtype=torch.int32, - device=self.device, - ) - cu_num_tokens, accepted_token_indices, target_token_ids, \ - target_positions, target_hidden_states, target_slot_mapping = self.drafter.prepare_inputs( - attn_metadata.query_start_loc, - num_rejected_tokens, - self.input_ids[:num_scheduled_tokens], - positions[:num_scheduled_tokens], - hidden_states[:num_scheduled_tokens], - attn_metadata.slot_mapping[:num_scheduled_tokens], - is_torchair_graph=self._build_drafter_prepare_inputs_torchair_param(), - ) - - draft_token_ids = self.drafter.propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, - next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=attn_metadata.block_tables, - sampling_metadata=sampling_metadata, - token_indices=accepted_token_indices) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids - def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, @@ -2839,7 +3665,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc_np[req_idx].item() prompt_hidden_states = hidden_states[offset:offset + num_logits] - logits = self.model.compute_logits(prompt_hidden_states, None) + logits = self._compute_logits_wrapper(prompt_hidden_states, None) # Get the "target" tokens for each index. 
For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py deleted file mode 100644 index e8f369f..0000000 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ /dev/null @@ -1,439 +0,0 @@ -import types - -import torch -import torch.nn as nn -import torchair -import vllm.envs as envs_vllm -from torchair import patch_for_hcom -from vllm.attention.layer import Attention -from vllm.config import (VllmConfig, get_layers_from_vllm_config, - set_current_vllm_config) -from vllm.forward_context import get_forward_context -from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.utils import ( - process_weights_after_loading, set_default_torch_dtype) -from vllm.v1.sample.metadata import SamplingMetadata - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP -from vllm_ascend.torchair.models.torchair_deepseek_mtp import \ - TorchairDeepSeekMTP -from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR, - TorchairCommonAttentionMetadata) -from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable - - -class MtpProposer: - - def __init__( - self, - vllm_config: VllmConfig, - runner, - ): - self.vllm_config = vllm_config - self.num_speculative_tokens = ( - vllm_config.speculative_config.num_speculative_tokens) - self.block_size = vllm_config.cache_config.block_size - self.hidden_size = vllm_config.model_config.get_hidden_size() - self.runner = runner - # persistent buffers for graph - self.input_ids = torch.zeros(self.runner.max_num_tokens, - dtype=torch.int32, - device=self.runner.device) - self.positions = torch.zeros(self.runner.max_num_tokens, - dtype=torch.int64, - device=self.runner.device) - 
self.hidden_states = torch.zeros( - (self.runner.max_num_tokens, self.hidden_size), - dtype=self.runner.dtype, - device=self.runner.device) - self.torchair_compiled_model = None # type: ignore - self.torchair_compiled_models = {} # type: ignore - self.torchair_graph_enabled = get_ascend_config( - ).torchair_graph_config.enabled - - @staticmethod - def prepare_inputs( - # [batch_size + 1] - cu_target_query_lens: torch.Tensor, - # [batch_size] - num_rejected_tokens: torch.Tensor, - token_ids: torch.Tensor, - positions: torch.Tensor, - hidden_states: torch.Tensor, - slot_mapping: torch.Tensor, - is_torchair_graph: bool = False - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: - # cu_target_query_lens: [0, a, a + b, a + b + c] - # num_rejected_tokens: [n1, n2, n3] - # num_tokens_per_req: [a - n1, b - n2, c - n3] - # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - # token_indices: [0, 1, ..., a - n1 - 1, - # a, a + 1, ..., a + b - n2 - 1, - # a + b, a + b + 1, ..., a + b + c - n3 - 1] - # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) - # [a, b, c] -> [a - n1, b - n2, c - n3] - num_tokens_per_req = query_len_per_req - num_rejected_tokens - if is_torchair_graph: - cu_num_tokens = cu_target_query_lens - relative_index = query_len_per_req - num_rejected_tokens - 1 - token_indices = cu_num_tokens[:-1] + relative_index - # the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model - target_token_ids = token_ids - target_positions = positions - target_hidden_states = hidden_states - target_slot_mapping = slot_mapping - else: - cu_num_tokens = torch.empty_like(cu_target_query_lens) - torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - cu_num_tokens[0] = 0 - - # FIXME(woosuk): Avoid synchronization. 
- num_tokens = cu_num_tokens[-1].item() - token_indices = torch.zeros( - num_tokens, - dtype=torch.int32, - device=cu_num_tokens.device, - ) - - BLOCK_SIZE = 1024 - prepare_input_kernel( - token_indices, - cu_target_query_lens, - cu_num_tokens, - block_size=BLOCK_SIZE, - ) - target_token_ids = token_ids[token_indices] - target_positions = positions[token_indices] - target_hidden_states = hidden_states[token_indices] - target_slot_mapping = slot_mapping[token_indices] - return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping - - def propose( - self, - # [num_tokens] - target_token_ids: torch.Tensor, - # [num_tokens] - target_positions: torch.Tensor, - # [num_tokens, hidden_size] - target_hidden_states: torch.Tensor, - # [num_tokens] - target_slot_mapping: torch.Tensor, - # [batch_size] - next_token_ids: torch.Tensor, - # [batch_size + 1] starting with 0 - cu_num_tokens: torch.Tensor, - # [batch_size, max_num_blocks_per_req] - block_table: torch.Tensor, - sampling_metadata: SamplingMetadata, - token_indices=None) -> torch.Tensor: - num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 - - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - if token_indices is not None and self.torchair_graph_enabled: - last_token_indices = token_indices - - self.input_ids[last_token_indices] = next_token_ids - - query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] - max_query_len = query_lens.max().item() - - # FIXME: reorder_batch() needs to be called before build() - # because fields of attn_metadata_builder needs to be updated. 
- # However, currently reorder_batch() takes input_batch and - # scheduler_output as arguments, we should probably refactor - # the method to use new data structures which are independent - # from input_batch and scheduler_output. - # self.runner.attn_metadata_builder.reorder_batch( - # input_batch=self.runner.input_batch, - # scheduler_output=self.runner.scheduler_output, - # ) - is_running_torchair = self.torchair_graph_enabled and \ - not self.runner.with_prefill - - if is_running_torchair: - num_input_tokens = self.runner.graph_pad_size - else: - num_input_tokens = num_tokens - - seq_lens = target_positions[last_token_indices] + 1 - seq_lens = seq_lens.int() - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=cu_num_tokens[:batch_size + 1], - query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(), - seq_lens_cpu=seq_lens.cpu(), - num_reqs=batch_size, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0]. 
- get_device_tensor(), - slot_mapping_cpu=target_slot_mapping, - positions=target_positions, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - attn_state=self.runner.attn_state, - graph_pad_size=self.runner.graph_pad_size, - decode_token_per_req=self.runner.decode_token_per_req, - ) - attn_metadata = self.runner.attn_metadata_builder.build( - common_attn_metadata, self.runner.get_model()) - - self.positions[:num_tokens] = target_positions - self.hidden_states[:num_tokens] = target_hidden_states - - if not self.torchair_graph_enabled: - # torch mode need to update num_tokens_across_dp - # TODO: adapt enable_dbo later - (num_input_tokens, num_tokens_across_dp, with_prefill, - _) = self.runner._sync_metadata_across_dp( - num_tokens, self.runner.with_prefill, False) - attn_metadata.slot_mapping = target_slot_mapping - else: - # torchair mode can reuse self.runner.num_tokens_across_dp - num_tokens_across_dp = self.runner.num_tokens_across_dp - with_prefill = self.runner.with_prefill - - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - with_prefill=with_prefill, - num_tokens_across_dp=num_tokens_across_dp, - reserved_mc2_mask=self.runner.reserved_mc2_mask, - in_profile_run=self.runner.in_profile_run, - num_actual_tokens=num_tokens): - with ProfileExecuteDuration().capture_async('mtp_forward'): - model_kwargs = {} - model_kwargs["attn_metadata"] = attn_metadata - if self.torchair_graph_enabled: - model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] - if is_running_torchair: - torchair_compiled_model = self._get_torchair_lazy_compiled_model( - num_input_tokens) - hidden_states = torchair_compiled_model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - previous_hidden_states=self. 
- hidden_states[:num_input_tokens], - inputs_embeds=None, - intermediate_tensors=None, - spec_step_idx=0, - **model_kwargs) - else: - hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - previous_hidden_states=self. - hidden_states[:num_input_tokens], - kv_caches=self.runner.kv_caches[-1:]) - - num_indices = last_token_indices.shape[0] - if lmhead_tp_enable(): - if not self.runner.with_prefill: - max_num_reqs_across_dp = num_input_tokens - else: - max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs - last_token_indices = nn.functional.pad( - last_token_indices, (0, max_num_reqs_across_dp - num_indices)) - - sample_hidden_states = hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - if lmhead_tp_enable() and num_indices < logits.shape[0]: - logits = logits[:num_indices] - draft_token_ids = logits.argmax(dim=-1) - - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - def load_model(self) -> None: - loader = get_model_loader(self.vllm_config.load_config) - - target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config - target_device = self.vllm_config.device_config.device - - with set_default_torch_dtype( - draft_model_config.dtype), set_current_vllm_config( - self.vllm_config): - if self.torchair_graph_enabled: - self.model = TorchairDeepSeekMTP( - vllm_config=self.vllm_config).to(target_device) - else: - self.model = CustomDeepSeekMTP( - vllm_config=self.vllm_config).to(target_device) - - draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - target_attn_layer_names) - - assert len(draft_attn_layer_names) == 1 - self.attn_layer_name = next(iter(draft_attn_layer_names)) - - self.model.load_weights( - loader.get_all_weights( - 
self.vllm_config.speculative_config.draft_model_config, - self.model)) - process_weights_after_loading(self.model, draft_model_config, - target_device) - - @torch.inference_mode() - def dummy_run(self, - num_tokens: int, - with_prefill: bool = False, - skip_attn: bool = False, - num_reqs: int = 0, - num_tokens_across_dp=None) -> None: - if not self.torchair_graph_enabled: - # TODO: adapt enable_dbo later - (num_tokens, num_tokens_across_dp, with_prefill, - _) = self.runner._sync_metadata_across_dp(num_tokens, - with_prefill, False) - is_running_torchair = self.torchair_graph_enabled and \ - not with_prefill - - if is_running_torchair: - skip_attn = False - if skip_attn: - attn_metadata = None - else: - common_attn_metadata = TorchairCommonAttentionMetadata( - num_reqs=num_reqs, - num_actual_tokens=1, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - decode_token_per_req=self.runner.decode_token_per_req, - ) - attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( - common_attn_metadata) - - input_ids = self.input_ids[:num_tokens] - positions = self.positions[:num_tokens] - previous_hidden_states = self.hidden_states[:num_tokens] - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - with_prefill=with_prefill, - num_tokens_across_dp=num_tokens_across_dp, - reserved_mc2_mask=self.runner.reserved_mc2_mask, - in_profile_run=self.runner.in_profile_run, - num_actual_tokens=0): - if is_running_torchair: - assert attn_metadata is not None - torch._dynamo.mark_static(input_ids) - torch._dynamo.mark_static(positions) - torch._dynamo.mark_static(previous_hidden_states) - torch._dynamo.mark_static(attn_metadata.decode.block_table) - torch._dynamo.mark_static(attn_metadata.decode.input_positions) - if hasattr(attn_metadata.decode, "sin"): - torch._dynamo.mark_static(attn_metadata.decode.sin) - 
torch._dynamo.mark_static(attn_metadata.decode.cos) - torch._dynamo.mark_static(get_forward_context().mc2_mask) - torch._dynamo.mark_static(attn_metadata.slot_mapping) - torch._dynamo.mark_static(attn_metadata.decode.attn_mask) - torchair_compiled_model = self._get_torchair_lazy_compiled_model( - num_tokens) - torchair_compiled_model( - input_ids=input_ids, - positions=positions, - previous_hidden_states=previous_hidden_states, - inputs_embeds=None, - intermediate_tensors=None, - attn_metadata=attn_metadata, - kv_caches=self.runner.kv_caches[-1:], - spec_step_idx=0) - else: - self.model(input_ids=input_ids, - positions=positions, - previous_hidden_states=previous_hidden_states) - - def _get_torchair_lazy_compiled_model(self, batch_size: int): - if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[ - -1]: - raise ValueError( - f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}" - ) - - compiled_model = self.torchair_compiled_models.get( - batch_size - ) if self.runner.use_cached_npu_graph else self.torchair_compiled_model - - if compiled_model: - return compiled_model - - patch_for_hcom() - config = torchair.CompilerConfig() - config.experimental_config.frozen_parameter = True - config.experimental_config.tiling_schedule_optimize = True - config.experimental_config.enable_view_optimize = \ - get_ascend_config().torchair_graph_config.enable_view_optimize - torch.npu.set_compile_mode(jit_compile=False) - if not self.runner.use_cached_npu_graph: - npu_backend = torchair.get_npu_backend(compiler_config=config) - self.torchair_compiled_model = torch.compile( - self.model, - dynamic=True, - fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=npu_backend) - return self.torchair_compiled_model - else: - # Generate a new forward proxy code object to prevent the invalidation of - # compilation cache caused by dynamo retracing - forward_proxy_name = 
f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" - forward_fn = self.model.forward - code = forward_fn.__code__ - # Mark code object with a new proxy name - modified_code = code.replace(co_name=forward_proxy_name, ) - - modified_func = types.FunctionType(modified_code, - forward_fn.__globals__, - name=forward_proxy_name, - argdefs=forward_fn.__defaults__) - - self.model.__dict__[forward_proxy_name] = modified_func.__get__( - self.model, nn.Module) - self.torchair_compiled_models[ - batch_size] = torchair.inference.cache_compile( - self.model.__dict__[forward_proxy_name], - dynamic=True, - fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - cache_dir=TORCHAIR_CACHE_DIR, - config=config, - ge_cache=False) - return self.torchair_compiled_models[batch_size] - - -# TODO Using torch instead of triton may result in poor performance -def prepare_input_kernel(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor, - cu_num_tokens: torch.Tensor, block_size: int): - device = cu_query_lens.device - dtype = out_ptr.dtype - - offsets = torch.arange(block_size, device=device, dtype=dtype) - start_pos = cu_num_tokens[:-1] - end_pos = cu_num_tokens[1:] - num_tokens = end_pos - start_pos - - global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1)) - values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)) - - mask = (offsets.view(1, -1) < num_tokens.view(-1, 1)) - - global_indices_flat = global_indices[mask] - values_flat = values[mask] - out_ptr[global_indices_flat] = values_flat diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index cbd25a8..d1ebd02 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -24,8 +24,9 @@ import numpy as np import torch from typing_extensions import deprecated from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs 
import (MultiModalFeatureSpec, + MultiModalKwargsItem, + MultiModalKwargsItems, PlaceholderRange) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values @@ -37,9 +38,9 @@ from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import MultiGroupBlockTable from vllm_ascend.utils import vllm_version_is +from vllm_ascend.worker.block_table import MultiGroupBlockTable @dataclass @@ -47,10 +48,6 @@ class CachedRequestState: req_id: str prompt_token_ids: list[int] - mm_kwargs: list[MultiModalKwargsItem] - mm_positions: list[PlaceholderRange] - # TODO: remove Optional after 0.10.1.1 - mm_hashes: Optional[list[str]] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] generator: Optional[torch.Generator] @@ -62,6 +59,12 @@ class CachedRequestState: mrope_positions: Optional[torch.Tensor] = None mrope_position_delta: Optional[int] = None + mm_features: Optional[list[MultiModalFeatureSpec]] = None + # for back-compatibility, will be removed in next major release + mm_kwargs: Optional[list[MultiModalKwargsItem]] = None + mm_positions: Optional[list[PlaceholderRange]] = None + mm_hashes: Optional[list[PlaceholderRange]] = None + lora_request: Optional[LoRARequest] = None def __post_init__(self): @@ -75,8 +78,18 @@ class CachedRequestState: @property @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be " "removed in v0.13. 
Please use `mm_kwargs` instead.") - def mm_inputs(self) -> list[MultiModalKwargs]: - return [MultiModalKwargs([item]) for item in self.mm_kwargs] + def mm_inputs(self) -> list[MultiModalKwargsItems]: + if vllm_version_is("0.10.2"): + assert self.mm_kwargs is not None + return [ + MultiModalKwargsItems.from_seq([item]) + for item in self.mm_kwargs + ] + assert self.mm_features is not None + return [ + MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features + if f.data is not None + ] def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: @@ -88,18 +101,19 @@ class CachedRequestState: class InputBatch: def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - device: torch.device, - pin_memory: bool, - vocab_size: int, - block_sizes: list[int], # The block_size of each kv cache group - logitsprocs: Optional[LogitsProcessors] = None, - is_spec_decode: bool = False, - is_pooling_model: bool = False, - ): + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + block_sizes: list[int], # The block_size of each kv cache group + logitsprocs: Optional[LogitsProcessors] = None, + is_spec_decode: bool = False, + is_pooling_model: bool = False, + num_speculative_tokens: int = 0, + kernel_block_sizes: Optional[list[list[int]]] = None): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode self.max_num_reqs = max_num_reqs @@ -143,7 +157,8 @@ class InputBatch: pin_memory=pin_memory, device=device, block_sizes=block_sizes, - ) + num_speculative_tokens=num_speculative_tokens, + kernel_sizes=kernel_block_sizes) # Sampling-related. 
self.temperature = torch.empty((max_num_reqs, ), @@ -218,6 +233,14 @@ class InputBatch: self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() + # Speculative decoding + self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ), + dtype=torch.int64, + device="cpu", + pin_memory=pin_memory) + self.num_accepted_tokens_cpu = \ + self.num_accepted_tokens_cpu_tensor.numpy() + # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), dtype=np.int32) @@ -266,6 +289,11 @@ class InputBatch: self.pooling_params: dict[str, PoolingParams] = {} + # Cached reference to the GPU tensor of previously sampled tokens + self.prev_sampled_token_ids: Optional[torch.Tensor] = None + self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None + self.prev_req_id_to_index: Optional[dict[str, int]] = None + @property def req_ids(self) -> list[str]: # None elements should only be present transiently @@ -407,6 +435,9 @@ class InputBatch: else: raise NotImplementedError(request) + # Speculative decoding: by default 1 token is generated. 
+ self.num_accepted_tokens_cpu[req_index] = 1 + # Add request lora ID if request.lora_request: lora_id = request.lora_request.lora_int_id @@ -506,6 +537,8 @@ class InputBatch: self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\ + self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1] # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -612,6 +645,8 @@ class InputBatch: empty_index] = self.presence_penalties_cpu[last_req_index] self.repetition_penalties_cpu[ empty_index] = self.repetition_penalties_cpu[last_req_index] + self.num_accepted_tokens_cpu[ + empty_index] = self.num_accepted_tokens_cpu[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator @@ -726,20 +761,13 @@ class InputBatch: pooling_params = [ self.pooling_params[req_id] for req_id in self.req_ids ] - if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): - return PoolingMetadata( - prompt_lens=torch.from_numpy( - self.num_prompt_tokens[:self.num_reqs]).to(self.device), - prompt_token_ids=self.sampling_metadata.prompt_token_ids, - pooling_params=pooling_params, - ) - else: - return PoolingMetadata( - prompt_lens=torch.from_numpy( - self.num_prompt_tokens[:self.num_reqs]), - prompt_token_ids=self.sampling_metadata.prompt_token_ids, - pooling_params=pooling_params, - ) + + return PoolingMetadata( + prompt_lens=torch.from_numpy( + self.num_prompt_tokens[:self.num_reqs]), + prompt_token_ids=self.sampling_metadata.prompt_token_ids, + pooling_params=pooling_params, + ) def _make_prompt_token_ids_tensor(self) -> torch.Tensor: max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() diff --git a/vllm_ascend/worker/worker_v1.py 
b/vllm_ascend/worker/worker_v1.py index 1062d47..dc82ece 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -18,18 +18,18 @@ # import copy -from typing import Optional +from typing import Optional, Union import torch import torch.nn as nn import torch_npu import vllm.envs as envs_vllm from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions +from torch_npu.profiler import dynamic_profile as dp from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) -from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import logger from vllm.lora.request import LoRARequest @@ -38,22 +38,31 @@ from vllm.tasks import SupportedTask from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, ModelRunnerOutput) from vllm.v1.worker.worker_base import WorkerBase -from vllm_ascend.ascend_config import init_ascend_config +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (init_ascend_soc_version, register_ascend_customop, sleep_mode_enabled, - try_register_lib, vllm_version_is) + try_register_lib) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner -if not (vllm_version_is("0.10.1.1") or 
vllm_version_is("0.10.1")): - from vllm.v1.outputs import DraftTokenIds -else: - DraftTokenIds = None +torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402 +from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402 + +torch_non_c_binding_in_graph_functions_npu = dict.fromkeys( + ["torch.npu.current_stream"], + TorchInGraphFunctionVariable, +) # noqa: E402 +torch_non_c_binding_in_graph_functions_npu[ + "torch.npu.stream"] = TorchInGraphFunctionVariable # noqa: E402 +torch._dynamo.trace_rules.torch_name_rule_map.append( + torch_non_c_binding_in_graph_functions_npu) # noqa: E402 class NPUWorker(WorkerBase): @@ -75,10 +84,21 @@ class NPUWorker(WorkerBase): from vllm_ascend import ops ops.register_dummy_fusion_op() _register_atb_extensions() - register_ascend_customop() + register_ascend_customop(vllm_config) # init ascend config and soc version init_ascend_config(vllm_config) init_ascend_soc_version() + if get_ascend_config().use_sfa: + # Direct import instead of using try_register_lib to ensure proper error handling when + # custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments) + # yapf: disable + import custom_ops # type: ignore # noqa + + # yapf: enable + logger.info( + "custom_ops module loaded successfully. Custom operators like " + "torch.ops.custom.npu_sparse_flash_attention are now available." 
+ ) super().__init__(vllm_config=vllm_config, local_rank=local_rank, @@ -103,6 +123,15 @@ class NPUWorker(WorkerBase): init_cached_hf_modules() self.profiler = self._init_profiler() + if sleep_mode_enabled(): + # Buffers saved before sleep + self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + + # FixMe: this is a patch to fix the issue caused by https://github.com/vllm-project/vllm/commit/de94289a98d7ec52a5ef02719e01a1db8b505170 + from vllm.model_executor.layers.linear import \ WEIGHT_LOADER_V2_SUPPORTED + if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED: + WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod") def sleep(self, level: int = 1) -> None: if not sleep_mode_enabled(): @@ -110,6 +139,13 @@ class NPUWorker(WorkerBase): "Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1." ) free_bytes_before_sleep = NPUPlatform.mem_get_info()[0] + # Save the buffers before level 2 sleep + if level == 2: + model = self.model_runner.model + self._sleep_saved_buffers = { + name: buffer.cpu().clone() + for name, buffer in model.named_buffers() + } allocator = CaMemAllocator.get_instance() allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) free_bytes_after_sleep, total = NPUPlatform.mem_get_info() @@ -129,6 +165,14 @@ class NPUWorker(WorkerBase): allocator = CaMemAllocator.get_instance() allocator.wake_up(tags=tags) + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers): + model = self.model_runner.model + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks @@ -195,36 +239,42 @@ class NPUWorker(WorkerBase): def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> Optional[ModelRunnerOutput]: + ) -> 
Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: + # enable msMonitor to monitor the performance of vllm-ascend + if envs_ascend.MSMONITOR_USE_DAEMON: + dp.step() + intermediate_tensors = None - if not get_pp_group().is_first_rank: + forward_pass = scheduler_output.total_num_scheduled_tokens > 0 + if forward_pass and not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group())) output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) + if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): + return output + + assert isinstance(output, IntermediateTensors) parallel_config = self.vllm_config.parallel_config - if parallel_config.distributed_executor_backend != "external_launcher" \ - and not get_pp_group().is_last_rank: - assert isinstance(output, IntermediateTensors) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - if not has_kv_transfer_group(): - return None + assert parallel_config.distributed_executor_backend != ( + "external_launcher") and not get_pp_group().is_last_rank - kv_connector_output = output.kv_connector_output - finished_sending = kv_connector_output.finished_sending - finished_recving = kv_connector_output.finished_recving + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) - if not finished_sending and not finished_recving: - return EMPTY_MODEL_RUNNER_OUTPUT + kv_connector_output = output.kv_connector_output + if not kv_connector_output: + return None - new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - new_output.kv_connector_output = kv_connector_output - return new_output - - assert isinstance(output, ModelRunnerOutput) + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if (not kv_connector_output.finished_sending + and not kv_connector_output.finished_recving): + return EMPTY_MODEL_RUNNER_OUTPUT + output = 
copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output return output def load_model(self) -> None: @@ -242,6 +292,7 @@ class NPUWorker(WorkerBase): def compile_or_warm_up_model(self) -> None: # Note: need to adapt for graph mode. + self.model_runner.eplb_warmup() warmup_sizes = (self.vllm_config.compilation_config.compile_sizes or []).copy() if not self.model_config.enforce_eager: @@ -254,10 +305,19 @@ class NPUWorker(WorkerBase): self.model_runner._dummy_run(size) if not self.model_config.enforce_eager: self.model_runner.capture_model() + # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache) + # may cause performance degradation at runtime. + self._warm_up_atb() # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. NPUPlatform.seed_everything(self.model_config.seed) + def _warm_up_atb(self): + x = torch.rand((2, 4), dtype=torch.float16).npu() + weight = torch.rand((2, 4), dtype=torch.float16).npu() + c = torch.rand((4, 4), dtype=torch.float32).npu() + torch_npu._npu_matmul_add_fp32(x, weight, c) + def get_model(self) -> nn.Module: return self.model_runner.get_model() @@ -313,6 +373,10 @@ class NPUWorker(WorkerBase): # Torch profiler. Enabled and configured through env vars: # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs_vllm.VLLM_TORCH_PROFILER_DIR: + if envs_ascend.MSMONITOR_USE_DAEMON: + raise RuntimeError( + "MSMONITOR_USE_DAEMON and VLLM_TORCH_PROFILER_DIR cannot be both set at the same time." + ) torch_profiler_trace_dir = envs_vllm.VLLM_TORCH_PROFILER_DIR logger.info("Profiling enabled. 
Traces will be saved to: %s", torch_profiler_trace_dir) From 5a0e920ec1541d9ca6d7964ba35e8896865d4591 Mon Sep 17 00:00:00 2001 From: luopingyi Date: Tue, 21 Oct 2025 10:17:39 +0800 Subject: [PATCH 2/2] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 918e5d7..d8aec83 100644 --- a/README.md +++ b/README.md @@ -77,5 +77,5 @@ curl -X POST http://localhost:10086/v1/chat/completions \ | Version | Release type | Doc | |------------|--------------|--------------------------------------| -|v0.10.1rc1| 最新RC版本 |请查看[快速开始](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html)和[安装指南](https://vllm-ascend.readthedocs.io/en/latest/installation.html)了解更多| +|v0.11.0rc0| 最新RC版本 |请查看[快速开始](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html)和[安装指南](https://vllm-ascend.readthedocs.io/en/latest/installation.html)了解更多| |v0.9.1| 最新正式/稳定版本 |[快速开始](https://vllm-ascend.readthedocs.io/en/v0.9.1-dev/quick_start.html) and [安装指南](https://vllm-ascend.readthedocs.io/en/v0.9.1-dev/installation.html)了解更多|