From 6ea2afe5fabbe575db8c992bae176a4866385c02 Mon Sep 17 00:00:00 2001
From: Ronald
Date: Wed, 7 Jan 2026 09:11:26 +0800
Subject: [PATCH] [Feature] implement basic framework for batch invariant
 (#5517)

### What this PR does / why we need it?

This PR implements the basic framework for batch invariance; see
https://github.com/vllm-project/vllm-ascend/issues/5487.

### Does this PR introduce _any_ user-facing change?

No new option is added: we reuse the function `vllm_is_batch_invariant`
from vLLM to judge whether batch-invariant mode is enabled.

- vLLM version: v0.13.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/45c1ca1ca1ee8fa06df263c8715e8a412ff408d4

---------

Signed-off-by: Ronald1995
Signed-off-by: Lord_of_Ironhill
Signed-off-by: zjchenn
Signed-off-by: wangx700
Co-authored-by: Lord_of_Ironhill
Co-authored-by: zjchenn
Co-authored-by: wangx700
---
 .github/workflows/_e2e_test.yaml              |   1 +
 tests/e2e/singlecard/test_batch_invariant.py  | 672 ++++++++++++++++++
 vllm_ascend/batch_invariant.py                |  82 +++
 .../ops/triton/batch_invariant/__init__.py    |   0
 .../ops/triton/batch_invariant/matmul.py      | 403 +++++++++++
 .../ops/triton/batch_invariant/mean.py        | 177 +++++
 .../ops/triton/batch_invariant/rmsnorm.py     | 153 ++++
 .../ops/triton/batch_invariant/softmax.py     |  29 +
 vllm_ascend/worker/worker.py                  |   2 +
 9 files changed, 1519 insertions(+)
 create mode 100644 tests/e2e/singlecard/test_batch_invariant.py
 create mode 100644 vllm_ascend/batch_invariant.py
 create mode 100644 vllm_ascend/ops/triton/batch_invariant/__init__.py
 create mode 100644 vllm_ascend/ops/triton/batch_invariant/matmul.py
 create mode 100644 vllm_ascend/ops/triton/batch_invariant/mean.py
 create mode 100644 vllm_ascend/ops/triton/batch_invariant/rmsnorm.py
 create mode 100644 vllm_ascend/ops/triton/batch_invariant/softmax.py

diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index f58e04e8..817cfea5 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -123,6 +123,7 @@ jobs:
           pytest -sv --durations=0 tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
           pytest -sv --durations=0 tests/e2e/singlecard/model_runner_v2/test_basic.py
+          pytest -sv --durations=0 tests/e2e/singlecard/test_batch_invariant.py

   e2e-2-cards:
     name: multicard-2
diff --git a/tests/e2e/singlecard/test_batch_invariant.py b/tests/e2e/singlecard/test_batch_invariant.py
new file mode 100644
index 00000000..d4fd423c
--- /dev/null
+++ b/tests/e2e/singlecard/test_batch_invariant.py
@@ -0,0 +1,672 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/tests/v1/determinism/test_batch_invariant.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# +import os +import random + +import pytest +import torch +from vllm import LLM, SamplingParams + +from tests.e2e.conftest import cleanup_dist_env_and_memory + +DEFAULT_MODEL = "Qwen/Qwen3-0.6B" + + +@pytest.fixture(autouse=True) +def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch): + """Automatically enable batch invariant kernel overrides for all tests.""" + monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1") + + +def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: + # Generate more realistic prompts that will actually produce varied tokens + # Use a mix of common English text patterns + + prompt_templates = [ + # Question-answer style + "Question: What is the capital of France?\nAnswer: The capital of France is", + "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which", + "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is", + # Story/narrative style + "Once upon a time in a distant galaxy, there lived", + "The old man walked slowly down the street, remembering", + "In the year 2157, humanity finally discovered", + # Technical/code style + "To implement a binary search tree in Python, first we need to", + "The algorithm works by iterating through the array and", + "Here's how to optimize database queries using indexing:", + # Factual/informative style + "The Renaissance was a period in European history that", + "Climate change is caused by several factors including", + "The human brain contains approximately 86 billion neurons which", + # Conversational style + "I've been thinking about getting a new laptop because", + "Yesterday I went to the store and bought", + "My favorite thing about summer is definitely", + ] + + # Pick a random template + base_prompt = random.choice(prompt_templates) + + if max_words < min_words: + max_words = min_words + target_words = random.randint(min_words, max_words) + + if target_words > 50: + # For longer prompts, repeat context + padding_text = ( + " This is an interesting topic that deserves more explanation. " * + (target_words // 50)) + base_prompt = base_prompt + padding_text + + return base_prompt + + +def _extract_step_logprobs(request_output): + if getattr(request_output, "outputs", None): + inner = request_output.outputs[0] + if hasattr(inner, "logprobs") and inner.logprobs is not None: + t = torch.tensor( + [ + inner.logprobs[i][tid].logprob + for i, tid in enumerate(inner.token_ids) + ], + dtype=torch.float32, + ) + return t, inner.token_ids + + return None, None + + +@pytest.mark.timeout(1000) +def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( + monkeypatch: pytest.MonkeyPatch): + """ + Ensures that the same request (the 'needle' prompt) yields identical output + whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64), + using the high-level v1 LLM() API only (no manual batching). + + Strategy: + - Create two LLM engines with identical config except max_num_seqs: 1 vs N. + - Compute a baseline output for the needle prompt with the bs=1 engine. + - For many trials, generate a batch (size N) where the needle appears at a + random position among random filler prompts using the bs=N engine. + - Track how many trials match vs mismatch, and report totals at the end. + The test fails if any mismatches occur, but we still dump pass/fail + counts. + + Notes: + - Use seeded stochastic sampling with a fixed seed to test determinism. 
+ - Outputs are intentionally longer and sampled at higher temperature/top_p + to produce a more random-sounding phrase, yet remain deterministic by + seed. + - Keep max_tokens and max_model_len bounded for speed and memory use. + """ + seed = int(os.getenv("VLLM_TEST_SEED", "12345")) + random.seed(seed) + + # Allow overrides from environment (useful for CI tuning) + model = DEFAULT_MODEL + + num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5")) + max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "144")) + min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024")) + max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048")) + assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle." + + # Keep GPU memory usage low to avoid startup allocation failures. + gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.95")) + max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120")) + + # Sampling parameters: longer outputs with a more random-sounding + # continuation,but still deterministic due to fixed seed. + temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0")) + top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95")) + max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "35")) + + sampling = SamplingParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + seed=20240919, + ) + + needle_prompt = "There once was a " + + llm = None + try: + # Engine with bs=1 behavior + llm = LLM( + model=model, + max_num_seqs=max_batch_size, + gpu_memory_utilization=gpu_mem_util, + max_model_len=max_model_len, + dtype="bfloat16", + tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), + enable_prefix_caching=False, + enforce_eager=True, + distributed_executor_backend="mp", + # Enable for MOE models + # enable_expert_parallel=True, + ) + + # Baseline generation for the needle prompt alone. 
+ baseline_out = llm.generate([needle_prompt], sampling) + assert len(baseline_out) == 1 + assert len(baseline_out[0].outputs) >= 1 + baseline_text = baseline_out[0].outputs[0].text + + mismatches = 0 + + for trial in range(num_trials): + # Create a batch of size `max_batch_size` and insert the needle at + # a random index + prompts: list[str] = [] + batch_size = random.randint(max_batch_size // 2, max_batch_size) + needle_pos = random.randint(0, batch_size - 1) + for i in range(batch_size): + if i == needle_pos: + prompts.append(needle_prompt) + else: + prompts.append( + _random_prompt(min_random_prompt, max_random_prompt)) + + # Generate with the larger-batch engine + outputs = llm.generate(prompts, sampling) + # Find the needle output by position + needle_output = outputs[needle_pos] + assert needle_output.prompt == needle_prompt + assert len(needle_output.outputs) >= 1 + text = needle_output.outputs[0].text + + if text != baseline_text: + print( + f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n") + mismatches += 1 + + passes = num_trials - mismatches + # Dump how many passed vs failed + print(f"[determinism] total={num_trials}, passed={passes}, " + f"failed={mismatches}, max_batch_size={max_batch_size}") + + if mismatches > 0: + pytest.fail( + f"Nondeterministic outputs detected: {mismatches} failed out " + f"of {num_trials} trials (max_batch_size={max_batch_size}).") + + finally: + del llm + cleanup_dist_env_and_memory() + + +def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( + monkeypatch: pytest.MonkeyPatch): + seed = int(os.getenv("VLLM_TEST_SEED", "12345")) + random.seed(seed) + model_name = DEFAULT_MODEL + tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) + + # For batch invariance, disable custom all-reduce to ensure deterministic + # all-reduce operations (custom all-reduce may not be deterministic) + from vllm_ascend.batch_invariant import vllm_is_batch_invariant + + disable_custom_ar = vllm_is_batch_invariant() + + if disable_custom_ar: + print(f"\n{'=' * 80}") + print( + f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})" + ) + print(f"{'=' * 80}\n") + + llm = LLM( + model=model_name, + tensor_parallel_size=tp_size, + enable_prefix_caching=False, + max_num_seqs=32, + max_model_len=8192, + dtype="bfloat16", + gpu_memory_utilization=0.9, + enforce_eager=True, + distributed_executor_backend="mp", + ) + + # Use more realistic prompts for better token generation + prompts = [_random_prompt(10, 50) for i in range(32)] + + sp = SamplingParams( + temperature=0.6, + top_p=1.0, + max_tokens=8, + seed=1234, + logprobs=5, + ) + + # BS=1: run prompts individually and collect logprobs per step. + print("\n" + "=" * 80) + print("STARTING BS=1 RUNS (each prompt individually)") + print("=" * 80 + "\n") + + bs1_logprobs_per_prompt = [] + bs1_tokens_per_prompt = [] + for idx, p in enumerate(prompts): + print( + f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..." + ) + outs = llm.generate([p], sp, use_tqdm=False) + assert len(outs) == 1 + step_logprobs, token_ids = _extract_step_logprobs(outs[0]) + if step_logprobs is None: + pytest.skip("Logits are not available on RequestOutput; " + "enable logprobs return to run this test.") + bs1_logprobs_per_prompt.append(step_logprobs) + bs1_tokens_per_prompt.append(token_ids) + print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}") + + # BS=N: run prompts in a batch and collect logprobs per step for each + # prompt. 
+ print("\n" + "=" * 80) + print(f"STARTING BS={len(prompts)} RUN (all prompts batched)") + print("=" * 80 + "\n") + + outs_batched = llm.generate(prompts, sp, use_tqdm=False) + assert len(outs_batched) == len(prompts) + bsN_logprobs_per_prompt = [] + bsN_tokens_per_prompt = [] + + print(f"\n[BS={len(prompts)}] Processing batched outputs...") + for idx, o in enumerate(outs_batched): + tokens = o.outputs[0].token_ids if o.outputs else "N/A" + print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}") + step_logprobs, token_ids = _extract_step_logprobs(o) + if step_logprobs is None: + pytest.skip("Logits are not available on RequestOutput; " + "enable logprobs return to run this test.") + bsN_logprobs_per_prompt.append(step_logprobs) + bsN_tokens_per_prompt.append(token_ids) + + # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs. + failed_prompts = [] + for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate( + zip( + bs1_logprobs_per_prompt, + bsN_logprobs_per_prompt, + bs1_tokens_per_prompt, + bsN_tokens_per_prompt, + )): + if len(logprobs_bs1) != len(logprobs_bsN): + reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) " + f"vs {len(logprobs_bsN)} (BS=N)") + failed_prompts.append({ + "prompt_idx": i, + "step": "all", + "reason": reason, + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + }) + continue + + # Check if tokens match first + if tokens_bs1 != tokens_bsN: + failed_prompts.append({ + "prompt_idx": + i, + "step": + "sampling", + "reason": + "Different tokens sampled", + "prompt_preview": + prompts[i][:100], + "bs1_tokens": + tokens_bs1, + "bsN_tokens": + tokens_bsN, + "bs1_all_logprobs": + [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))], + "bsN_all_logprobs": + [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))], + }) + continue + + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)): + if a.shape != b.shape: + failed_prompts.append({ + "prompt_idx": i, + "step": t, + "reason": f"Shape mismatch: {a.shape} vs {b.shape}", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + }) + break + + if not torch.equal(a, b): + max_diff = torch.abs(a - b).max().item() + # Print which token failed + print( + f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}" + ) + bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A" + bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A" + print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}") + print(f" BS=1 logprob: {a.tolist()}") + print(f" BS=N logprob: {b.tolist()}") + failed_prompts.append({ + "prompt_idx": + i, + "step": + t, + "reason": + f"Bitwise mismatch (max_diff={max_diff:.6e})", + "prompt_preview": + prompts[i][:100], + "bs1_tokens": + tokens_bs1, + "bsN_tokens": + tokens_bsN, + "bs1_all_logprobs": [ + logprobs_bs1[s].tolist() + for s in range(len(logprobs_bs1)) + ], + "bsN_all_logprobs": [ + logprobs_bsN[s].tolist() + for s in range(len(logprobs_bsN)) + ], + }) + break + del llm + cleanup_dist_env_and_memory() + # Print summary of all failures + if failed_prompts: + print(f"\n{'=' * 80}") + fail_msg = (f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/" + f"{len(prompts)} prompts failed") + print(fail_msg) + print(f"{'=' * 80}") + for fail in failed_prompts: + print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):") + print(f" Reason: {fail['reason']}") + print(f" Preview: {fail['prompt_preview']}...") + + # Always show the tokens + if 
"bs1_tokens" in fail: + print(f" BS=1 tokens: {fail['bs1_tokens']}") + if "bsN_tokens" in fail: + print(f" BS=N tokens: {fail['bsN_tokens']}") + + if "bs1_all_logprobs" in fail: + print( + f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:" + ) + for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]): + print(f" Step {step_idx}: {logprobs}") + print( + f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:" + ) + for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]): + print(f" Step {step_idx}: {logprobs}") + print(f"{'=' * 80}\n") + + # Fail the test with summary + msg = (f"Batch invariance violated in {len(failed_prompts)}/" + f"{len(prompts)} prompts. See output above for details.") + pytest.fail(msg) + + +def test_simple_generation(monkeypatch: pytest.MonkeyPatch): + """ + Simple test that runs the model with a basic prompt and prints the output. + Useful for quick smoke testing and debugging. + """ + model = DEFAULT_MODEL + + llm = LLM( + model=model, + max_num_seqs=1, + tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), + gpu_memory_utilization=0.9, + max_model_len=2048, + dtype="float16", + enable_prefix_caching=False, + enforce_eager=True, + distributed_executor_backend="mp", + ) + + prompt = "The capital of France is" + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=20, + ) + + print(f"\n{'=' * 80}") + print("Running simple generation test") + print(f"Prompt: '{prompt}'") + print(f"{'=' * 80}\n") + + try: + outputs = llm.generate([prompt], sampling_params) + + assert len(outputs) == 1 + output_text = outputs[0].outputs[0].text + + print(f"Output: '{output_text}'") + print(f"\n{'=' * 80}") + print(f"Full completion: '{prompt}{output_text}'") + print(f"{'=' * 80}\n") + + finally: + del llm + cleanup_dist_env_and_memory() + + +def test_logprobs_without_batch_invariance_should_fail( + monkeypatch: pytest.MonkeyPatch): + """ + This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN. + It DISABLES batch invariance mode and expects to see non-deterministic behavior + between BS=1 and BS=N runs. This demonstrates that batch invariance is actually + doing something useful. + + The test will PASS if we detect differences (proving batch invariance matters). + The test will FAIL if everything matches (suggesting batch invariance isn't needed). 
+ """ + # CRITICAL: Disable batch invariance for this test + monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0") + seed = int(os.getenv("VLLM_TEST_SEED", "12345")) + random.seed(seed) + model_name = DEFAULT_MODEL + tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) + + print(f"\n{'=' * 80}") + print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior") + print(f"{'=' * 80}\n") + + llm = LLM( + model=model_name, + tensor_parallel_size=tp_size, + enable_prefix_caching=False, + max_num_seqs=32, + max_model_len=8192, + dtype="bfloat16", + enforce_eager=True, + distributed_executor_backend="mp", + ) + + # build ragged prompts to change shapes significantly across BS=1 vs BS=N + long_min = int(os.getenv("VLLM_MIN_PROMPT", "768")) + long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048")) + prompts: list[str] = [] + options = [ + (max(long_min, 1536), max(long_max, 3072)), # very long + (max(1024, long_min), max(2048, long_max)), # long + (256, 512), # mid + (10, 20), # short + ] + + for _ in range(32): + lo, hi = random.choice(options) + prompts.append(_random_prompt(lo, hi)) + + sp = SamplingParams( + temperature=0.6, + top_p=1.0, + max_tokens=8, + seed=1234, + logprobs=5, + ) + + # BS=1: run prompts individually and collect logprobs per step. + print("\n" + "=" * 80) + print("STARTING BS=1 RUNS (each prompt individually)") + print("=" * 80 + "\n") + + bs1_logprobs_per_prompt = [] + bs1_tokens_per_prompt = [] + for idx, p in enumerate(prompts): + print( + f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..." + ) + outs = llm.generate([p], sp, use_tqdm=False) + assert len(outs) == 1 + step_logprobs, token_ids = _extract_step_logprobs(outs[0]) + if step_logprobs is None: + pytest.skip("Logits are not available on RequestOutput; " + "enable logprobs return to run this test.") + bs1_logprobs_per_prompt.append(step_logprobs) + bs1_tokens_per_prompt.append(token_ids) + print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}") + + # BS=N: run prompts in a batch and collect logprobs per step for each prompt. + print("\n" + "=" * 80) + print(f"STARTING BS={len(prompts)} RUN (all prompts batched)") + print("=" * 80 + "\n") + + outs_batched = llm.generate(prompts, sp, use_tqdm=False) + assert len(outs_batched) == len(prompts) + bsN_logprobs_per_prompt = [] + bsN_tokens_per_prompt = [] + + print(f"\n[BS={len(prompts)}] Processing batched outputs...") + for idx, o in enumerate(outs_batched): + tokens = o.outputs[0].token_ids if o.outputs else "N/A" + print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}") + step_logprobs, token_ids = _extract_step_logprobs(o) + if step_logprobs is None: + pytest.skip("Logits are not available on RequestOutput; " + "enable logprobs return to run this test.") + bsN_logprobs_per_prompt.append(step_logprobs) + bsN_tokens_per_prompt.append(token_ids) + + # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs. 
+ differences_found = [] + for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate( + zip( + bs1_logprobs_per_prompt, + bsN_logprobs_per_prompt, + bs1_tokens_per_prompt, + bsN_tokens_per_prompt, + )): + if len(logprobs_bs1) != len(logprobs_bsN): + reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) " + f"vs {len(logprobs_bsN)} (BS=N)") + differences_found.append({ + "prompt_idx": i, + "step": "all", + "reason": reason, + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + }) + continue + + # Check if tokens match first + if tokens_bs1 != tokens_bsN: + differences_found.append({ + "prompt_idx": i, + "step": "sampling", + "reason": "Different tokens sampled", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + }) + continue + + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)): + if a.shape != b.shape: + differences_found.append({ + "prompt_idx": i, + "step": t, + "reason": f"Shape mismatch: {a.shape} vs {b.shape}", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + }) + break + + if not torch.equal(a, b): + max_diff = torch.abs(a - b).max().item() + print(f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, " + f"Token {t}: max_diff={max_diff:.6e}") + bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A" + bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A" + print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}") + print(f" BS=1 logprob: {a.tolist()}") + print(f" BS=N logprob: {b.tolist()}") + differences_found.append({ + "prompt_idx": i, + "step": t, + "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + }) + break + del llm + cleanup_dist_env_and_memory() + # Print summary + print(f"\n{'=' * 80}") + if differences_found: + success_msg = ( + f"✓ SUCCESS: Batch invariance is doing something! " + f"Found {len(differences_found)}/{len(prompts)} prompts " + f"with differences when batch invariance was DISABLED.") + print(success_msg) + print(f"{'=' * 80}") + for diff in differences_found: + print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):") + print(f" Reason: {diff['reason']}") + print(f" Preview: {diff['prompt_preview']}...") + if "bs1_tokens" in diff: + print(f" BS=1 tokens: {diff['bs1_tokens']}") + if "bsN_tokens" in diff: + print(f" BS=N tokens: {diff['bsN_tokens']}") + print(f"{'=' * 80}\n") + # Test PASSES because we found differences (batch invariance matters!) + return + else: + # Test FAILS because everything matched even without batch invariance + fail_msg = ( + f"✗ UNEXPECTED: All {len(prompts)} prompts matched " + f"between BS=1 and BS=N even with batch invariance DISABLED. " + f"This suggests batch invariance might not be necessary, " + f"or the test needs more sensitive prompts.") + print(fail_msg) + print(f"{'=' * 80}\n") + pytest.fail(fail_msg) diff --git a/vllm_ascend/batch_invariant.py b/vllm_ascend/batch_invariant.py new file mode 100644 index 00000000..50acadb1 --- /dev/null +++ b/vllm_ascend/batch_invariant.py @@ -0,0 +1,82 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+import os
+
+import torch
+from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+from vllm.triton_utils import HAS_TRITON
+
+logger = init_logger(__name__)
+
+if HAS_TRITON:
+    from vllm_ascend.ops.triton.batch_invariant.matmul import (
+        addmm_batch_invariant, bmm_batch_invariant, linear_batch_invariant,
+        matmul_batch_invariant, mm_batch_invariant)
+
+
+def override_envs_for_invariance():
+    # TODO(Ronald): set the attention backend to deterministic mode.
+
+    # Enabling NZ mode feeds NZ-format inputs to the Triton operators,
+    # which causes accuracy anomalies, so keep it off.
+    os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"
+
+    # Communication determinism settings.
+    os.environ["HCCL_DETERMINISTIC"] = "true"
+    os.environ["LCCL_DETERMINISTIC"] = "1"
+
+
+_batch_invariant_LIB = None
+
+
+def enable_batch_invariant_mode():
+    global _batch_invariant_LIB
+
+    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
+
+    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
+    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
+    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
+    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
+    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
+
+
+def init_batch_invariance():
+    """
+    Initialize batch-invariant mode for vLLM on Ascend NPU.
+
+    This function:
+    1. Sets environment variables for deterministic computation
+    2. Registers batch-invariant implementations for torch operators
+    3. Will enable a batch-invariant attention backend once it is available
+       (see the TODO above)
+
+    Set the VLLM_BATCH_INVARIANT=1 environment variable to enable this mode;
+    the NPU worker calls this function during distributed-environment
+    initialization, and it is a no-op when the flag is not set.
+    """
+    if vllm_is_batch_invariant():
+        if HAS_TRITON:
+            logger.info(
+                "Enabling batch-invariant mode for vLLM on Ascend NPU.", )
+            override_envs_for_invariance()
+            enable_batch_invariant_mode()
+        else:
+            logger.warning(
+                "Batch-invariant mode requested but Triton is not available; "
+                "skipping batch-invariant initialization.", )
diff --git a/vllm_ascend/ops/triton/batch_invariant/__init__.py b/vllm_ascend/ops/triton/batch_invariant/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vllm_ascend/ops/triton/batch_invariant/matmul.py b/vllm_ascend/ops/triton/batch_invariant/matmul.py
new file mode 100644
index 00000000..0b7934ad
--- /dev/null
+++ b/vllm_ascend/ops/triton/batch_invariant/matmul.py
@@ -0,0 +1,403 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch +from triton.runtime import driver # type: ignore +from vllm.triton_utils import tl, triton + + +@triton.jit +def matmul_bias_persistent_kernel( + # Input tensor pointers + x_ptr, + y_ptr, + bias_ptr, + output_ptr, + # Matrix dimensions + M, + N, + K, + # Stride information + stride_xm, + stride_xk, # Strides of x: [M, K] + stride_yk, + stride_yn, # Strides of y: [K, N] + stride_bias, # Stride of bias: [N] + stride_outm, + stride_outn, # Strides of output: [M, N] + # Whether to use bias + has_bias: tl.constexpr, + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_m = tl.program_id(0) # Row block ID + pid_n = tl.program_id(1) # Column block ID + + # Calculate the starting position of the current block in the matrix + rm_start = pid_m * BLOCK_M + rn_start = pid_n * BLOCK_N + + # Create index ranges + rm = rm_start + tl.arange(0, BLOCK_M) # Row index range [BLOCK_M] + rn = rn_start + tl.arange(0, BLOCK_N) # Column index range [BLOCK_N] + rk = tl.arange(0, BLOCK_K) # K dimension index range [BLOCK_K] + + # Initialize accumulator to 0 + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + # Loop over the K dimension, processing BLOCK_K elements per iteration + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_start = k * BLOCK_K + # Calculate pointer offsets for x (row-major) + x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] + + k_start) * stride_xk + # Calculate pointer offsets for y (row-major) + y_ptrs = y_ptr + (rk[:, None] + + k_start) * stride_yk + rn[None, :] * stride_yn + + # Create masks to prevent out-of-bounds access + x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K) + y_mask = ((rk[:, None] + k_start) < K) & (rn[None, :] < N) + + # Load data chunks from global memory + x_chunk = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32) + y_chunk = tl.load(y_ptrs, mask=y_mask, other=0.0).to(tl.float32) + + # Compute matrix multiplication accumulation + acc += tl.dot(x_chunk, y_chunk, allow_tf32=False) + + # Add bias if the has_bias flag is set + if has_bias: + # Load bias values (broadcast to all rows) + bias_ptrs = bias_ptr + rn * stride_bias + bias_mask = rn < N + bias_vals = tl.load(bias_ptrs, mask=bias_mask, + other=0.0).to(tl.float32) + # Add bias to accumulator (automatic broadcasting) + acc += bias_vals[None, :] + + # Calculate output pointer positions + out_ptrs = output_ptr + rm[:, + None] * stride_outm + rn[None, :] * stride_outn + out_mask = (rm[:, None] < M) & (rn[None, :] < N) + + # Store result to global memory + tl.store(out_ptrs, acc.to(out_ptrs.dtype.element_ty), mask=out_mask) + + +def matmul_persistent(x, y, bias=None): + """ + Implement matrix multiplication with optional bias using Triton: x @ y + bias (if bias is not None) + + Parameters: + x: torch.Tensor, shape [M, K] + y: torch.Tensor, shape [K, N] + bias: torch.Tensor, shape [N] or None + + Returns: + output: torch.Tensor, shape [M, N] + """ + # Validate input shapes + assert x.dim() == 2, "x must be a 2D tensor" + assert y.dim() == 2, "y must be a 2D tensor" + assert 
x.shape[1] == y.shape[ + 0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}" + + M, K = x.shape + _, N = y.shape + # Validate bias shape (if not None) + if bias is not None: + assert bias.dim() == 1, "bias must be a 1D tensor" + assert y.shape[1] == bias.shape[ + 0], f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}" + + # Allocate output tensor (same data type as x) + output = torch.empty((M, N), dtype=x.dtype, device=x.device) + + # Define block sizes (can be adjusted based on hardware) + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128 + + # Calculate grid size (one thread per block) + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + + # Handle case when bias is None + if bias is None: + # Create a dummy bias tensor (will not be used as has_bias=False) + dummy_bias = torch.empty(0, dtype=x.dtype, device=x.device) + has_bias = False + bias_stride = 0 + bias_to_pass = dummy_bias + else: + has_bias = True + bias_stride = bias.stride(0) + bias_to_pass = bias + # Launch kernel + matmul_bias_persistent_kernel[grid]( + x, + y, + bias_to_pass, + output, # Input/Output tensors + M, + N, + K, # Matrix dimensions + x.stride(0), + x.stride(1), # Strides of x + y.stride(0), + y.stride(1), # Strides of y + bias_stride, # Stride of bias (0 if bias is None) + output.stride(0), + output.stride(1), # Strides of output + has_bias, # Flag indicating whether to use bias + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ) + + return output + + +@triton.jit +def linear_persistent_kernel( + a_ptr, # Pointer to tensor a, shape [M, K] + b_ptr, # Pointer to tensor b, shape [N, K] + c_ptr, # Pointer to output tensor c, shape [M, N] + M, # Number of rows in tensor a + N, # Number of rows in tensor b (number of columns in output c) + K, # Number of columns in both tensor a and tensor b + stride_am, # Stride of tensor a along dimension M (typically K) + stride_ak, # Stride of tensor a along dimension K (typically 1) + stride_bn, # Stride of tensor b along dimension N (typically K) + stride_bk, # Stride of tensor b along dimension K (typically 1) + stride_cm, # Stride of tensor c along dimension M (typically N) + stride_cn, # Stride of tensor c along dimension N (typically 1) + BLOCK_M: tl.constexpr, # Block size for M dimension + BLOCK_N: tl.constexpr, # Block size for N dimension + BLOCK_K: tl.constexpr, # Block size for K dimension + NUM_BLOCKS_M: tl.constexpr, # New: Number of blocks in M dimension + NUM_BLOCKS_N: tl.constexpr, # New: Number of blocks in N dimension + GRID_SIZE: tl.constexpr, # New: Fixed 1D grid size +): + # Get current program's 1D index (1D grid) + pid = tl.program_id(0) + total_blocks = NUM_BLOCKS_M * NUM_BLOCKS_N # Total number of output blocks + + # Loop over multiple blocks assigned to the current program + for block_index in range(pid, total_blocks, GRID_SIZE): + # Convert 1D block index to 2D coordinates (m_block, n_block) + m_block = block_index // NUM_BLOCKS_N + n_block = block_index % NUM_BLOCKS_N + + # Calculate starting indices of the current output block + start_m = m_block * BLOCK_M + start_n = n_block * BLOCK_N + + # Create row and column index ranges within the current block + m_indices = start_m + tl.arange(0, BLOCK_M) + n_indices = start_n + tl.arange(0, BLOCK_N) + + # Create masks to handle boundaries + m_mask = m_indices < M + n_mask = n_indices < N + + # Initialize accumulator to 0 + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + # Loop over K dimension with step size BLOCK_K + for k_offset 
in range(0, K, BLOCK_K): + k_indices = k_offset + tl.arange(0, BLOCK_K) + k_mask = k_indices < K + + # Load block of tensor a: shape [BLOCK_M, BLOCK_K] + a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[ + None, :] * stride_ak + a_vals = tl.load(a_ptrs, + mask=m_mask[:, None] & k_mask[None, :], + other=0.0) + + # Load block of tensor b: shape [BLOCK_N, BLOCK_K] + b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[ + None, :] * stride_bk + b_vals = tl.load(b_ptrs, + mask=n_mask[:, None] & k_mask[None, :], + other=0.0) + + # Explicitly transpose b matrix using tl.trans: shape becomes [BLOCK_K, BLOCK_N] + b_vals_transposed = tl.trans(b_vals) + + # Compute matrix multiplication: a_vals × b_vals_transposed + product = tl.dot(a_vals, b_vals_transposed) + acc += product + # Store result to output tensor c + c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[ + None, :] * stride_cn + tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :]) + + +def linear_persistent(x, y): + """ + Implement matrix multiplication with Triton: x @ y^T + Uses a fixed-size 1D grid + + Parameters: + x: torch.Tensor, shape [M, K] + y: torch.Tensor, shape [N, K] + + Returns: + output: torch.Tensor, shape [M, N] + """ + # Validate input shapes + assert x.dim() == 2, "x must be a 2D tensor" + assert y.dim() == 2, "y must be a 2D tensor" + assert x.shape[1] == y.shape[ + 1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}" + + M, K = x.shape + N, _ = y.shape + + # Allocate output tensor (same data type as x) + output = torch.zeros((M, N), dtype=x.dtype, device=x.device) + + # Define block sizes (can be adjusted based on hardware) + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128 + + # Calculate number of blocks per dimension (ceil division) + num_blocks_m = triton.cdiv(M, BLOCK_M) + num_blocks_n = triton.cdiv(N, BLOCK_N) + + # Set fixed 1D grid size + grid_size = driver.active.utils.get_device_properties( + torch.npu.current_device())["num_vectorcore"] // 2 + grid = (grid_size, ) + + # Launch kernel + linear_persistent_kernel[grid]( + a_ptr=x, + b_ptr=y, + c_ptr=output, + M=M, + N=N, + K=K, + stride_am=x.stride(0), + stride_ak=x.stride(1), + stride_bn=y.stride(0), + stride_bk=y.stride(1), + stride_cm=output.stride(0), + stride_cn=output.stride(1), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + NUM_BLOCKS_M=num_blocks_m, # Number of blocks in M dimension + NUM_BLOCKS_N=num_blocks_n, # Number of blocks in N dimension + GRID_SIZE=grid_size, # Fixed grid size + ) + + return output + + +def mm_batch_invariant(a, b): + return matmul_persistent(a, b) + + +def bmm_batch_invariant(a, b, *, out=None): + # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N) + # Process each batch separately with our persistent kernel + if a.ndim == 3 and b.ndim == 3: + results = [] + for i in range(a.shape[0]): + results.append(matmul_persistent(a[i], b[i])) + result = torch.stack(results, dim=0) + + if out is not None: + out.copy_(result) + return out + return result + else: + raise ValueError(f"bmm_batch_invariant expects 3D tensors, " + f"got shapes {a.shape} and {b.shape}") + + +def addmm_batch_invariant(bias, a, b): + return matmul_persistent(a, b, bias=bias) + + +def matmul_batch_invariant(a, b, *, out=None): + # torch.matmul can handle various dimensions + # For 2D x 2D, it's the same as matmul + if a.ndim == 2 and b.ndim == 2: + result = matmul_persistent(a, b) + if out is not None: + out.copy_(result) + return out + return result + elif a.ndim == 3 and b.ndim == 3: 
+ # Handle batched case like bmm + return bmm_batch_invariant(a, b, out=out) + elif a.ndim == 3 and b.ndim == 2: + # Handle 3D x 2D: common for linear layers + # (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out) + # Reshape to 2D, do mm, reshape back + batch, seq, hidden = a.shape + a_2d = a.reshape(-1, hidden) + result_2d = matmul_persistent(a_2d, b) + result = result_2d.reshape(batch, seq, -1) + if out is not None: + out.copy_(result) + return out + return result + elif a.ndim == 2 and b.ndim == 3: + # Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N) + # By broadcasting `a` to 3D, we can reuse the batched matrix + # multiplication logic. + a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1) + return bmm_batch_invariant(a_expanded, b, out=out) + elif a.ndim == 4 and b.ndim == 4: + # Handle 4D attention tensors: [batch, heads, seq, dim] + # Reshape to 3D, process, reshape back + batch, heads, seq_a, dim_a = a.shape + _, _, dim_b, seq_b = b.shape + + # Reshape to [batch*heads, seq_a, dim_a] + a_3d = a.reshape(batch * heads, seq_a, dim_a) + b_3d = b.reshape(batch * heads, dim_b, seq_b) + + # Do batched matmul + result_3d = bmm_batch_invariant(a_3d, b_3d) + + # Reshape back to [batch, heads, seq_a, seq_b] + result = result_3d.reshape(batch, heads, seq_a, seq_b) + + if out is not None: + out.copy_(result) + return out + return result + else: + raise ValueError( + f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, " + f"3D x 2D, 2D x 3D, and 4D x 4D, " + f"got shapes {a.shape} and {b.shape}") + + +def linear_batch_invariant(input_, weight, bias=None): + output = linear_persistent(input_, weight) + + if bias is not None: + output = output + bias + return output diff --git a/vllm_ascend/ops/triton/batch_invariant/mean.py b/vllm_ascend/ops/triton/batch_invariant/mean.py new file mode 100644 index 00000000..0a13f734 --- /dev/null +++ b/vllm_ascend/ops/triton/batch_invariant/mean.py @@ -0,0 +1,177 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch +from vllm.triton_utils import tl, triton + + +@triton.jit +def mean_kernel( + input_ptr, + output_ptr, + input_stride0, + input_stride1, + input_stride2, + output_stride0, + output_stride1, + M, # size before reduction dim + N, # size of reduction dim + K, # size after reduction dim + BLOCK_SIZE: tl.constexpr, +): + """ + Kernel for computing mean along a single dimension. + Input is viewed as (M, N, K) where N is the dimension being reduced. 
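+    Each program reduces a single (m, k) output element, walking the N axis in
+    fixed BLOCK_SIZE chunks, so the accumulation order never depends on the
+    batch size. For example, reducing dim=1 of a (2, 5, 3) tensor uses M=2,
+    N=5, K=3 and launches 6 programs.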
+ """ + # Program ID gives us which output element we're computing + pid = tl.program_id(0) + + # Compute output indices + m_idx = pid // K + k_idx = pid % K + + # Bounds check + if m_idx >= M or k_idx >= K: + return + # Accumulate sum across reduction dimension + acc = 0.0 + for n_start in range(0, N, BLOCK_SIZE): + n_offsets = n_start + tl.arange(0, BLOCK_SIZE) + mask = n_offsets < N + + # Calculate input indices + input_idx = m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2 + # Load and accumulate + vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0) + acc += tl.sum(vals) + + # Compute mean and store + mean_val = acc / N + output_idx = m_idx * output_stride0 + k_idx * output_stride1 + tl.store(output_ptr + output_idx, mean_val) + + +def mean_dim( + input_: torch.Tensor, + dim: int, + keepdim: bool = False, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """ + Triton implementation of torch.mean with single dimension reduction. + + Args: + input: Input tensor + dim: Single dimension along which to compute mean + keepdim: Whether to keep the reduced dimension + dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs) + + Returns: + Tensor with mean values along specified dimension + """ + # Validate inputs + assert -input_.ndim <= dim < input_.ndim, ( + f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions") + + # Handle negative dim + if dim < 0: + dim = dim + input_.ndim + # Handle dtype + if dtype is None: + if input_.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + dtype = torch.float32 + else: + dtype = input_.dtype + # Convert input to appropriate dtype if needed + if input_.dtype != dtype: + input_ = input_.to(dtype) + + # Get input shape and strides + shape = list(input_.shape) + # Calculate dimensions for kernel + M = 1 + for i in range(dim): + M *= shape[i] + + N = shape[dim] + + K = 1 + for i in range(dim + 1, len(shape)): + K *= shape[i] + + # Reshape input to 3D view (M, N, K) + input_3d = input_.reshape(M, N, K) + + # Create output shape + if keepdim: + output_shape = shape.copy() + output_shape[dim] = 1 + else: + output_shape = shape[:dim] + shape[dim + 1:] + + # Create output tensor + output = torch.empty(output_shape, dtype=dtype, device=input_.device) + + # Reshape output for kernel + if keepdim: + output_2d = output.reshape(M, 1, K).squeeze(1) + else: + output_2d = output.reshape(M, K) + + # Launch kernel + grid = (M * K, ) + BLOCK_SIZE = 1024 + + mean_kernel[grid]( + input_3d, + output_2d, + input_3d.stride(0), + input_3d.stride(1), + input_3d.stride(2), + output_2d.stride(0), + output_2d.stride(1) if output_2d.ndim > 1 else 0, + M, + N, + K, + BLOCK_SIZE, + ) + + return output + + +def mean_batch_invariant( + input_: torch.Tensor, + dim: list[int], + keepdim: bool = False, + dtype: torch.dtype = torch.float16, +): + assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}" + if len(dim) == 1: + return mean_dim(input_, dim[0], keepdim=keepdim) + else: + assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32 + }, ("only float types supported for now") + if len(dim) == 0: + dim = list(range(input_.ndim)) + n_elems = 1 + for d in dim: + n_elems *= input_.shape[d] + return torch.sum(input_, dim=dim, keepdim=keepdim, + dtype=torch.float32).to(dtype + or input_.dtype) / n_elems diff --git a/vllm_ascend/ops/triton/batch_invariant/rmsnorm.py b/vllm_ascend/ops/triton/batch_invariant/rmsnorm.py new file mode 100644 index 00000000..f4aa78a3 --- 
/dev/null +++ b/vllm_ascend/ops/triton/batch_invariant/rmsnorm.py @@ -0,0 +1,153 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch +from triton.runtime import driver # type: ignore +from vllm.triton_utils import tl, triton + + +@triton.jit +def _rms_norm_kernel( + input_ptr, + weight_ptr, + output_ptr, + input_row_stride, + output_row_stride, + n_rows, # 新增参数:总行数 + n_cols, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + Compute RMS normalization along the last dimension of a 2D tensor. + RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight + Each program handles multiple rows of the input tensor. + """ + pid = tl.program_id(0) + n_programs = tl.num_programs(0) + + rows_per_program = (n_rows + n_programs - 1) // n_programs + start_row = pid * rows_per_program + end_row = tl.minimum(start_row + rows_per_program, n_rows) + + for row_idx in range(start_row, end_row): + row_start_ptr = input_ptr + row_idx * input_row_stride + output_row_start_ptr = output_ptr + row_idx * output_row_stride + + # Step 1: Compute sum of squares in float32 to avoid overflow + sum_sq = tl.zeros([1], dtype=tl.float32) + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + vals_f32 = vals.to(tl.float32) + sq_vals = vals_f32 * vals_f32 + sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0)) + + # Step 2: Compute RMS (root mean square) in float32 + mean_sq = sum_sq / n_cols + rms = tl.sqrt(mean_sq + eps) + inv_rms = 1.0 / rms + + # Step 3: Normalize and apply weight + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0) + vals_f32 = vals.to(tl.float32) + weight_f32 = weight.to(tl.float32) + output_f32 = vals_f32 * inv_rms * weight_f32 + output = output_f32.to(vals.dtype) + tl.store(output_row_start_ptr + col_idx, output, mask=mask) + + +def rms_norm( + input_: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> torch.Tensor: + """ + Compute RMS normalization using Triton kernel with fixed grid size. 
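+    The grid is capped at the device's vector-core count; each program walks
+    its assigned rows in a fixed order and reduces every row serially in
+    BLOCK_SIZE chunks, so per-row results do not depend on how many rows are
+    processed together.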
+ + RMS Norm normalizes the input by the root mean square and scales by weight: + output = input / sqrt(mean(input^2) + eps) * weight + + Args: + input: Input tensor of shape (..., hidden_size) + weight: Weight tensor of shape (hidden_size,) + eps: Small constant for numerical stability + + Returns: + Tensor with RMS normalization applied along the last dimension + """ + assert weight.dim() == 1, "Weight must be 1-dimensional" + assert input_.shape[-1] == weight.shape[0], ( + f"Input last dimension ({input_.shape[-1]}) must match " + f"weight dimension ({weight.shape[0]})") + + # Flatten all dimensions except the last one + original_shape = input_.shape + input_2d = input_.reshape(-1, input_.shape[-1]) + input_2d = input_2d.contiguous() + weight = weight.contiguous() + + n_rows, n_cols = input_2d.shape + + output = torch.empty_like(input_2d, dtype=input_.dtype) + BLOCK_SIZE = 1024 + max_grid_size = driver.active.utils.get_device_properties( + torch.npu.current_device())["num_vectorcore"] + + grid = (min(n_rows, max_grid_size), ) + + _rms_norm_kernel[grid]( + input_2d, + weight, + output, + input_2d.stride(0), + output.stride(0), + n_rows, + n_cols, + eps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output.reshape(original_shape) + + +def rms_norm_batch_invariant( + input_: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> torch.Tensor: + """ + Batch-invariant wrapper for RMS normalization. + + This function provides a deterministic, batch-invariant implementation + of RMS normalization for use with the batch_invariant mode. + Args: + input: Input tensor of shape (..., hidden_size) + weight: Weight tensor of shape (hidden_size,) + eps: Small constant for numerical stability + + Returns: + RMS normalized tensor + """ + return rms_norm(input_, weight, eps=eps) diff --git a/vllm_ascend/ops/triton/batch_invariant/softmax.py b/vllm_ascend/ops/triton/batch_invariant/softmax.py new file mode 100644 index 00000000..37bf75cc --- /dev/null +++ b/vllm_ascend/ops/triton/batch_invariant/softmax.py @@ -0,0 +1,29 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+#
+import torch
+
+
+def softmax_batch_invariant(input_, dim, dtype=None):
+    # Compute softmax in a deterministic way.
+    # Honour the optional dtype the same way torch.softmax does: cast the
+    # input before the operation.
+    if dtype is not None:
+        input_ = input_.to(dtype)
+    # Subtract the max for numerical stability (standard practice).
+    input_max = torch.amax(input_, dim=dim, keepdim=True)
+    input_ = input_ - input_max
+    exp_x = torch.exp(input_)
+    sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
+    return exp_x / sum_exp_x
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index 6e437208..2da2551e 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -50,6 +50,7 @@ from vllm.v1.worker.workspace import init_workspace_manager
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
+from vllm_ascend.batch_invariant import init_batch_invariance
 from vllm_ascend.cpu_binding import bind_cpus
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
@@ -453,6 +454,7 @@ class NPUWorker(WorkerBase):
     def _init_worker_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
+        init_batch_invariance()
         init_distributed_environment(self.parallel_config.world_size,
                                      self.rank, self.distributed_init_method,
                                      self.local_rank, "hccl")
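For reference, here is a minimal sketch of how the pieces of this patch fit together at runtime. It assumes an Ascend environment with this PR applied; the model name and sampling values are only illustrative.

```python
import os

# vllm_is_batch_invariant() reads this flag, so set it before the engine
# (and its NPU workers) start up.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

# While each NPUWorker initializes its distributed environment, it calls
# init_batch_invariance(), which sets the HCCL/LCCL determinism env vars,
# disables NZ format, and registers the Triton kernels above for
# aten::mm/addmm/matmul/linear/bmm under the "NPU" dispatch key.
llm = LLM(model="Qwen/Qwen3-0.6B",
          enforce_eager=True,
          enable_prefix_caching=False)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)
```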