[Feature] implement basic framework for batch invariant (#5517)

### What this PR does / why we need it?
This PR implements the basic framework for batch invariance; please see
https://github.com/vllm-project/vllm-ascend/issues/5487.
### Does this PR introduce _any_ user-facing change?
We reuse the `vllm_is_batch_invariant` function from vLLM to determine
whether batch-invariant mode is enabled.
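
For example, enabling the mode is a one-line environment change (a minimal
sketch; the model name is simply the one used by the new e2e tests):

```python
import os

# Read by vLLM's vllm_is_batch_invariant(); the NPU worker then calls
# init_batch_invariance() to install the deterministic kernel overrides.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm import LLM

llm = LLM(model="Qwen/Qwen3-0.6B", enforce_eager=True)
```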

- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Signed-off-by: Lord_of_Ironhill <suiweiyi@huawei.com>
Signed-off-by: zjchenn <zjchenn@gmail.com>
Signed-off-by: wangx700 <wangxin700@huawei.com>
Co-authored-by: Lord_of_Ironhill <suiweiyi@huawei.com>
Co-authored-by: zjchenn <zjchenn@gmail.com>
Co-authored-by: wangx700 <wangxin700@huawei.com>
Ronald authored on 2026-01-07 09:11:26 +08:00, committed by GitHub
parent bdedf3c9f8 · commit 6ea2afe5fa
9 changed files with 1519 additions and 0 deletions

View File

@@ -123,6 +123,7 @@ jobs:
pytest -sv --durations=0 tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
pytest -sv --durations=0 tests/e2e/singlecard/model_runner_v2/test_basic.py
pytest -sv --durations=0 tests/e2e/singlecard/test_batch_invariant.py
e2e-2-cards:
name: multicard-2

View File

@@ -0,0 +1,672 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/v1/determinism/test_batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import os
import random
import pytest
import torch
from vllm import LLM, SamplingParams
from tests.e2e.conftest import cleanup_dist_env_and_memory
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests."""
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
# Generate more realistic prompts that will actually produce varied tokens
# Use a mix of common English text patterns
prompt_templates = [
# Question-answer style
"Question: What is the capital of France?\nAnswer: The capital of France is",
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
# Story/narrative style
"Once upon a time in a distant galaxy, there lived",
"The old man walked slowly down the street, remembering",
"In the year 2157, humanity finally discovered",
# Technical/code style
"To implement a binary search tree in Python, first we need to",
"The algorithm works by iterating through the array and",
"Here's how to optimize database queries using indexing:",
# Factual/informative style
"The Renaissance was a period in European history that",
"Climate change is caused by several factors including",
"The human brain contains approximately 86 billion neurons which",
# Conversational style
"I've been thinking about getting a new laptop because",
"Yesterday I went to the store and bought",
"My favorite thing about summer is definitely",
]
# Pick a random template
base_prompt = random.choice(prompt_templates)
if max_words < min_words:
max_words = min_words
target_words = random.randint(min_words, max_words)
if target_words > 50:
# For longer prompts, repeat context
padding_text = (
" This is an interesting topic that deserves more explanation. " *
(target_words // 50))
base_prompt = base_prompt + padding_text
return base_prompt
def _extract_step_logprobs(request_output):
if getattr(request_output, "outputs", None):
inner = request_output.outputs[0]
if hasattr(inner, "logprobs") and inner.logprobs is not None:
t = torch.tensor(
[
inner.logprobs[i][tid].logprob
for i, tid in enumerate(inner.token_ids)
],
dtype=torch.float32,
)
return t, inner.token_ids
return None, None
@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
monkeypatch: pytest.MonkeyPatch):
"""
Ensures that the same request (the 'needle' prompt) yields identical output
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
using the high-level v1 LLM() API only (no manual batching).
    Strategy:
    - Create one LLM engine with max_num_seqs=N.
    - Compute a baseline output for the needle prompt by running it alone
      (bs=1) on that engine.
    - For many trials, generate a batch (size up to N) where the needle
      appears at a random position among random filler prompts.
    - Track how many trials match vs mismatch, and report totals at the end.
    The test fails if any mismatches occur, but pass/fail counts are still
    dumped.
    Notes:
    - Sampling is seeded, so even stochastic settings remain deterministic.
    - Temperature/top_p are environment-configurable; the default here is
      greedy decoding (temperature=0.0).
    - Keep max_tokens and max_model_len bounded for speed and memory use.
"""
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
# Allow overrides from environment (useful for CI tuning)
model = DEFAULT_MODEL
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "144"))
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."
    # GPU memory utilization and model length are overridable for CI tuning.
gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.95"))
max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
    # Sampling parameters: longer outputs with a more random-sounding
    # continuation, but still deterministic due to the fixed seed.
temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "35"))
sampling = SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
seed=20240919,
)
needle_prompt = "There once was a "
llm = None
try:
        # Single engine, used for both the bs=1 baseline and batched runs.
llm = LLM(
model=model,
max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False,
enforce_eager=True,
distributed_executor_backend="mp",
# Enable for MOE models
# enable_expert_parallel=True,
)
# Baseline generation for the needle prompt alone.
baseline_out = llm.generate([needle_prompt], sampling)
assert len(baseline_out) == 1
assert len(baseline_out[0].outputs) >= 1
baseline_text = baseline_out[0].outputs[0].text
mismatches = 0
for trial in range(num_trials):
# Create a batch of size `max_batch_size` and insert the needle at
# a random index
prompts: list[str] = []
batch_size = random.randint(max_batch_size // 2, max_batch_size)
needle_pos = random.randint(0, batch_size - 1)
for i in range(batch_size):
if i == needle_pos:
prompts.append(needle_prompt)
else:
prompts.append(
_random_prompt(min_random_prompt, max_random_prompt))
# Generate with the larger-batch engine
outputs = llm.generate(prompts, sampling)
# Find the needle output by position
needle_output = outputs[needle_pos]
assert needle_output.prompt == needle_prompt
assert len(needle_output.outputs) >= 1
text = needle_output.outputs[0].text
if text != baseline_text:
print(
f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
mismatches += 1
passes = num_trials - mismatches
# Dump how many passed vs failed
print(f"[determinism] total={num_trials}, passed={passes}, "
f"failed={mismatches}, max_batch_size={max_batch_size}")
if mismatches > 0:
pytest.fail(
f"Nondeterministic outputs detected: {mismatches} failed out "
f"of {num_trials} trials (max_batch_size={max_batch_size}).")
finally:
del llm
cleanup_dist_env_and_memory()
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
monkeypatch: pytest.MonkeyPatch):
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = DEFAULT_MODEL
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
# For batch invariance, disable custom all-reduce to ensure deterministic
# all-reduce operations (custom all-reduce may not be deterministic)
from vllm_ascend.batch_invariant import vllm_is_batch_invariant
disable_custom_ar = vllm_is_batch_invariant()
if disable_custom_ar:
print(f"\n{'=' * 80}")
print(
f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})"
)
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
gpu_memory_utilization=0.9,
enforce_eager=True,
distributed_executor_backend="mp",
)
# Use more realistic prompts for better token generation
prompts = [_random_prompt(10, 50) for i in range(32)]
sp = SamplingParams(
temperature=0.6,
top_p=1.0,
max_tokens=8,
seed=1234,
logprobs=5,
)
# BS=1: run prompts individually and collect logprobs per step.
print("\n" + "=" * 80)
print("STARTING BS=1 RUNS (each prompt individually)")
print("=" * 80 + "\n")
bs1_logprobs_per_prompt = []
bs1_tokens_per_prompt = []
for idx, p in enumerate(prompts):
print(
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
)
outs = llm.generate([p], sp, use_tqdm=False)
assert len(outs) == 1
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
if step_logprobs is None:
pytest.skip("Logits are not available on RequestOutput; "
"enable logprobs return to run this test.")
bs1_logprobs_per_prompt.append(step_logprobs)
bs1_tokens_per_prompt.append(token_ids)
print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
# BS=N: run prompts in a batch and collect logprobs per step for each
# prompt.
print("\n" + "=" * 80)
print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
print("=" * 80 + "\n")
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
assert len(outs_batched) == len(prompts)
bsN_logprobs_per_prompt = []
bsN_tokens_per_prompt = []
print(f"\n[BS={len(prompts)}] Processing batched outputs...")
for idx, o in enumerate(outs_batched):
tokens = o.outputs[0].token_ids if o.outputs else "N/A"
print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
step_logprobs, token_ids = _extract_step_logprobs(o)
if step_logprobs is None:
pytest.skip("Logits are not available on RequestOutput; "
"enable logprobs return to run this test.")
bsN_logprobs_per_prompt.append(step_logprobs)
bsN_tokens_per_prompt.append(token_ids)
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
failed_prompts = []
for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
zip(
bs1_logprobs_per_prompt,
bsN_logprobs_per_prompt,
bs1_tokens_per_prompt,
bsN_tokens_per_prompt,
)):
if len(logprobs_bs1) != len(logprobs_bsN):
reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
f"vs {len(logprobs_bsN)} (BS=N)")
failed_prompts.append({
"prompt_idx": i,
"step": "all",
"reason": reason,
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
})
continue
# Check if tokens match first
if tokens_bs1 != tokens_bsN:
            failed_prompts.append({
                "prompt_idx": i,
                "step": "sampling",
                "reason": "Different tokens sampled",
                "prompt_preview": prompts[i][:100],
                "bs1_tokens": tokens_bs1,
                "bsN_tokens": tokens_bsN,
                "bs1_all_logprobs":
                [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
                "bsN_all_logprobs":
                [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
            })
continue
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
if a.shape != b.shape:
failed_prompts.append({
"prompt_idx": i,
"step": t,
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
})
break
if not torch.equal(a, b):
max_diff = torch.abs(a - b).max().item()
# Print which token failed
print(
f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}"
)
bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
print(f" BS=1 logprob: {a.tolist()}")
print(f" BS=N logprob: {b.tolist()}")
                failed_prompts.append({
                    "prompt_idx": i,
                    "step": t,
                    "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
                    "prompt_preview": prompts[i][:100],
                    "bs1_tokens": tokens_bs1,
                    "bsN_tokens": tokens_bsN,
                    "bs1_all_logprobs": [
                        logprobs_bs1[s].tolist()
                        for s in range(len(logprobs_bs1))
                    ],
                    "bsN_all_logprobs": [
                        logprobs_bsN[s].tolist()
                        for s in range(len(logprobs_bsN))
                    ],
                })
break
del llm
cleanup_dist_env_and_memory()
# Print summary of all failures
if failed_prompts:
print(f"\n{'=' * 80}")
fail_msg = (f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/"
f"{len(prompts)} prompts failed")
print(fail_msg)
print(f"{'=' * 80}")
for fail in failed_prompts:
print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):")
print(f" Reason: {fail['reason']}")
print(f" Preview: {fail['prompt_preview']}...")
# Always show the tokens
if "bs1_tokens" in fail:
print(f" BS=1 tokens: {fail['bs1_tokens']}")
if "bsN_tokens" in fail:
print(f" BS=N tokens: {fail['bsN_tokens']}")
if "bs1_all_logprobs" in fail:
print(
f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:"
)
for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]):
print(f" Step {step_idx}: {logprobs}")
print(
f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:"
)
for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]):
print(f" Step {step_idx}: {logprobs}")
print(f"{'=' * 80}\n")
# Fail the test with summary
msg = (f"Batch invariance violated in {len(failed_prompts)}/"
f"{len(prompts)} prompts. See output above for details.")
pytest.fail(msg)
def test_simple_generation(monkeypatch: pytest.MonkeyPatch):
"""
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
"""
model = DEFAULT_MODEL
llm = LLM(
model=model,
max_num_seqs=1,
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
gpu_memory_utilization=0.9,
max_model_len=2048,
dtype="float16",
enable_prefix_caching=False,
enforce_eager=True,
distributed_executor_backend="mp",
)
prompt = "The capital of France is"
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=20,
)
print(f"\n{'=' * 80}")
print("Running simple generation test")
print(f"Prompt: '{prompt}'")
print(f"{'=' * 80}\n")
try:
outputs = llm.generate([prompt], sampling_params)
assert len(outputs) == 1
output_text = outputs[0].outputs[0].text
print(f"Output: '{output_text}'")
print(f"\n{'=' * 80}")
print(f"Full completion: '{prompt}{output_text}'")
print(f"{'=' * 80}\n")
finally:
del llm
cleanup_dist_env_and_memory()
def test_logprobs_without_batch_invariance_should_fail(
monkeypatch: pytest.MonkeyPatch):
"""
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
It DISABLES batch invariance mode and expects to see non-deterministic behavior
between BS=1 and BS=N runs. This demonstrates that batch invariance is actually
doing something useful.
The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
"""
# CRITICAL: Disable batch invariance for this test
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = DEFAULT_MODEL
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
print(f"\n{'=' * 80}")
print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior")
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
enforce_eager=True,
distributed_executor_backend="mp",
)
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
long_min = int(os.getenv("VLLM_MIN_PROMPT", "768"))
long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
prompts: list[str] = []
options = [
(max(long_min, 1536), max(long_max, 3072)), # very long
(max(1024, long_min), max(2048, long_max)), # long
(256, 512), # mid
(10, 20), # short
]
for _ in range(32):
lo, hi = random.choice(options)
prompts.append(_random_prompt(lo, hi))
sp = SamplingParams(
temperature=0.6,
top_p=1.0,
max_tokens=8,
seed=1234,
logprobs=5,
)
# BS=1: run prompts individually and collect logprobs per step.
print("\n" + "=" * 80)
print("STARTING BS=1 RUNS (each prompt individually)")
print("=" * 80 + "\n")
bs1_logprobs_per_prompt = []
bs1_tokens_per_prompt = []
for idx, p in enumerate(prompts):
print(
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
)
outs = llm.generate([p], sp, use_tqdm=False)
assert len(outs) == 1
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
if step_logprobs is None:
pytest.skip("Logits are not available on RequestOutput; "
"enable logprobs return to run this test.")
bs1_logprobs_per_prompt.append(step_logprobs)
bs1_tokens_per_prompt.append(token_ids)
print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
# BS=N: run prompts in a batch and collect logprobs per step for each prompt.
print("\n" + "=" * 80)
print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
print("=" * 80 + "\n")
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
assert len(outs_batched) == len(prompts)
bsN_logprobs_per_prompt = []
bsN_tokens_per_prompt = []
print(f"\n[BS={len(prompts)}] Processing batched outputs...")
for idx, o in enumerate(outs_batched):
tokens = o.outputs[0].token_ids if o.outputs else "N/A"
print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
step_logprobs, token_ids = _extract_step_logprobs(o)
if step_logprobs is None:
pytest.skip("Logits are not available on RequestOutput; "
"enable logprobs return to run this test.")
bsN_logprobs_per_prompt.append(step_logprobs)
bsN_tokens_per_prompt.append(token_ids)
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
differences_found = []
for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
zip(
bs1_logprobs_per_prompt,
bsN_logprobs_per_prompt,
bs1_tokens_per_prompt,
bsN_tokens_per_prompt,
)):
if len(logprobs_bs1) != len(logprobs_bsN):
reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
f"vs {len(logprobs_bsN)} (BS=N)")
differences_found.append({
"prompt_idx": i,
"step": "all",
"reason": reason,
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
})
continue
# Check if tokens match first
if tokens_bs1 != tokens_bsN:
differences_found.append({
"prompt_idx": i,
"step": "sampling",
"reason": "Different tokens sampled",
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
})
continue
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
if a.shape != b.shape:
differences_found.append({
"prompt_idx": i,
"step": t,
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
})
break
if not torch.equal(a, b):
max_diff = torch.abs(a - b).max().item()
print(f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
f"Token {t}: max_diff={max_diff:.6e}")
bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
print(f" BS=1 logprob: {a.tolist()}")
print(f" BS=N logprob: {b.tolist()}")
differences_found.append({
"prompt_idx": i,
"step": t,
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
})
break
del llm
cleanup_dist_env_and_memory()
# Print summary
print(f"\n{'=' * 80}")
if differences_found:
success_msg = (
f"✓ SUCCESS: Batch invariance is doing something! "
f"Found {len(differences_found)}/{len(prompts)} prompts "
f"with differences when batch invariance was DISABLED.")
print(success_msg)
print(f"{'=' * 80}")
for diff in differences_found:
print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):")
print(f" Reason: {diff['reason']}")
print(f" Preview: {diff['prompt_preview']}...")
if "bs1_tokens" in diff:
print(f" BS=1 tokens: {diff['bs1_tokens']}")
if "bsN_tokens" in diff:
print(f" BS=N tokens: {diff['bsN_tokens']}")
print(f"{'=' * 80}\n")
# Test PASSES because we found differences (batch invariance matters!)
return
else:
# Test FAILS because everything matched even without batch invariance
fail_msg = (
f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
f"between BS=1 and BS=N even with batch invariance DISABLED. "
f"This suggests batch invariance might not be necessary, "
f"or the test needs more sensitive prompts.")
print(fail_msg)
print(f"{'=' * 80}\n")
pytest.fail(fail_msg)
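

if __name__ == "__main__":
    # Hedged convenience entry point (assumes a single Ascend NPU is visible);
    # mirrors the CI invocation added to the workflow in this PR.
    import sys
    sys.exit(pytest.main(["-sv", "--durations=0", __file__]))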

View File

@@ -0,0 +1,82 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import os
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.triton_utils import HAS_TRITON
logger = init_logger(__name__)
if HAS_TRITON:
from vllm_ascend.ops.triton.batch_invariant.matmul import (
addmm_batch_invariant, bmm_batch_invariant, linear_batch_invariant,
matmul_batch_invariant, mm_batch_invariant)
def override_envs_for_invariance():
    # TODO(Ronald): set the attention backend to deterministic mode.
    # Enabling NZ mode feeds NZ-format inputs to the Triton operators,
    # which causes accuracy anomalies, so disable it.
    os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"
    # Communication determinism settings.
os.environ["HCCL_DETERMINISTIC"] = "true"
os.environ["LCCL_DETERMINISTIC"] = "1"
_batch_invariant_LIB = None
def enable_batch_invariant_mode():
global _batch_invariant_LIB
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
def init_batch_invariance():
"""
Initialize batch-invariant mode for vLLM on Ascend NPU.
This function:
    1. Sets environment variables for deterministic computation.
    2. Registers batch-invariant implementations for torch matmul-family
       operators.
    (A deterministic attention backend is still a TODO; see
    override_envs_for_invariance.)
    Call this function early in your application, or set the
    VLLM_BATCH_INVARIANT=1 environment variable to enable it automatically.
"""
if vllm_is_batch_invariant():
if HAS_TRITON:
            logger.info(
                "Enabling batch-invariant mode for vLLM on Ascend NPU.")
override_envs_for_invariance()
enable_batch_invariant_mode()
else:
logger.warning(
"Batch-invariant mode requested but Triton is not available."
"skipping batch-invariant initialization.", )

View File

@@ -0,0 +1,403 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from triton.runtime import driver # type: ignore
from vllm.triton_utils import tl, triton
@triton.jit
def matmul_bias_persistent_kernel(
# Input tensor pointers
x_ptr,
y_ptr,
bias_ptr,
output_ptr,
# Matrix dimensions
M,
N,
K,
# Stride information
stride_xm,
stride_xk, # Strides of x: [M, K]
stride_yk,
stride_yn, # Strides of y: [K, N]
stride_bias, # Stride of bias: [N]
stride_outm,
stride_outn, # Strides of output: [M, N]
# Whether to use bias
has_bias: tl.constexpr,
# Block sizes
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0) # Row block ID
pid_n = tl.program_id(1) # Column block ID
# Calculate the starting position of the current block in the matrix
rm_start = pid_m * BLOCK_M
rn_start = pid_n * BLOCK_N
# Create index ranges
rm = rm_start + tl.arange(0, BLOCK_M) # Row index range [BLOCK_M]
rn = rn_start + tl.arange(0, BLOCK_N) # Column index range [BLOCK_N]
rk = tl.arange(0, BLOCK_K) # K dimension index range [BLOCK_K]
# Initialize accumulator to 0
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# Loop over the K dimension, processing BLOCK_K elements per iteration
for k in range(0, tl.cdiv(K, BLOCK_K)):
k_start = k * BLOCK_K
# Calculate pointer offsets for x (row-major)
x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] +
k_start) * stride_xk
# Calculate pointer offsets for y (row-major)
y_ptrs = y_ptr + (rk[:, None] +
k_start) * stride_yk + rn[None, :] * stride_yn
# Create masks to prevent out-of-bounds access
x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
y_mask = ((rk[:, None] + k_start) < K) & (rn[None, :] < N)
# Load data chunks from global memory
x_chunk = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
y_chunk = tl.load(y_ptrs, mask=y_mask, other=0.0).to(tl.float32)
# Compute matrix multiplication accumulation
acc += tl.dot(x_chunk, y_chunk, allow_tf32=False)
# Add bias if the has_bias flag is set
if has_bias:
# Load bias values (broadcast to all rows)
bias_ptrs = bias_ptr + rn * stride_bias
bias_mask = rn < N
bias_vals = tl.load(bias_ptrs, mask=bias_mask,
other=0.0).to(tl.float32)
# Add bias to accumulator (automatic broadcasting)
acc += bias_vals[None, :]
# Calculate output pointer positions
out_ptrs = output_ptr + rm[:,
None] * stride_outm + rn[None, :] * stride_outn
out_mask = (rm[:, None] < M) & (rn[None, :] < N)
# Store result to global memory
tl.store(out_ptrs, acc.to(out_ptrs.dtype.element_ty), mask=out_mask)
def matmul_persistent(x, y, bias=None):
"""
Implement matrix multiplication with optional bias using Triton: x @ y + bias (if bias is not None)
Parameters:
x: torch.Tensor, shape [M, K]
y: torch.Tensor, shape [K, N]
bias: torch.Tensor, shape [N] or None
Returns:
output: torch.Tensor, shape [M, N]
"""
# Validate input shapes
assert x.dim() == 2, "x must be a 2D tensor"
assert y.dim() == 2, "y must be a 2D tensor"
assert x.shape[1] == y.shape[
0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
M, K = x.shape
_, N = y.shape
# Validate bias shape (if not None)
if bias is not None:
assert bias.dim() == 1, "bias must be a 1D tensor"
assert y.shape[1] == bias.shape[
0], f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
# Allocate output tensor (same data type as x)
output = torch.empty((M, N), dtype=x.dtype, device=x.device)
# Define block sizes (can be adjusted based on hardware)
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128
    # Calculate grid size (one Triton program per output block)
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
# Handle case when bias is None
if bias is None:
# Create a dummy bias tensor (will not be used as has_bias=False)
dummy_bias = torch.empty(0, dtype=x.dtype, device=x.device)
has_bias = False
bias_stride = 0
bias_to_pass = dummy_bias
else:
has_bias = True
bias_stride = bias.stride(0)
bias_to_pass = bias
# Launch kernel
matmul_bias_persistent_kernel[grid](
x,
y,
bias_to_pass,
output, # Input/Output tensors
M,
N,
K, # Matrix dimensions
x.stride(0),
x.stride(1), # Strides of x
y.stride(0),
y.stride(1), # Strides of y
bias_stride, # Stride of bias (0 if bias is None)
output.stride(0),
output.stride(1), # Strides of output
has_bias, # Flag indicating whether to use bias
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
return output
@triton.jit
def linear_persistent_kernel(
a_ptr, # Pointer to tensor a, shape [M, K]
b_ptr, # Pointer to tensor b, shape [N, K]
c_ptr, # Pointer to output tensor c, shape [M, N]
M, # Number of rows in tensor a
N, # Number of rows in tensor b (number of columns in output c)
K, # Number of columns in both tensor a and tensor b
stride_am, # Stride of tensor a along dimension M (typically K)
stride_ak, # Stride of tensor a along dimension K (typically 1)
stride_bn, # Stride of tensor b along dimension N (typically K)
stride_bk, # Stride of tensor b along dimension K (typically 1)
stride_cm, # Stride of tensor c along dimension M (typically N)
stride_cn, # Stride of tensor c along dimension N (typically 1)
BLOCK_M: tl.constexpr, # Block size for M dimension
BLOCK_N: tl.constexpr, # Block size for N dimension
BLOCK_K: tl.constexpr, # Block size for K dimension
    NUM_BLOCKS_M: tl.constexpr,  # Number of blocks in the M dimension
    NUM_BLOCKS_N: tl.constexpr,  # Number of blocks in the N dimension
    GRID_SIZE: tl.constexpr,  # Fixed 1D grid size
):
# Get current program's 1D index (1D grid)
pid = tl.program_id(0)
total_blocks = NUM_BLOCKS_M * NUM_BLOCKS_N # Total number of output blocks
# Loop over multiple blocks assigned to the current program
for block_index in range(pid, total_blocks, GRID_SIZE):
# Convert 1D block index to 2D coordinates (m_block, n_block)
m_block = block_index // NUM_BLOCKS_N
n_block = block_index % NUM_BLOCKS_N
# Calculate starting indices of the current output block
start_m = m_block * BLOCK_M
start_n = n_block * BLOCK_N
# Create row and column index ranges within the current block
m_indices = start_m + tl.arange(0, BLOCK_M)
n_indices = start_n + tl.arange(0, BLOCK_N)
# Create masks to handle boundaries
m_mask = m_indices < M
n_mask = n_indices < N
# Initialize accumulator to 0
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# Loop over K dimension with step size BLOCK_K
for k_offset in range(0, K, BLOCK_K):
k_indices = k_offset + tl.arange(0, BLOCK_K)
k_mask = k_indices < K
# Load block of tensor a: shape [BLOCK_M, BLOCK_K]
a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[
None, :] * stride_ak
a_vals = tl.load(a_ptrs,
mask=m_mask[:, None] & k_mask[None, :],
other=0.0)
# Load block of tensor b: shape [BLOCK_N, BLOCK_K]
b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[
None, :] * stride_bk
b_vals = tl.load(b_ptrs,
mask=n_mask[:, None] & k_mask[None, :],
other=0.0)
# Explicitly transpose b matrix using tl.trans: shape becomes [BLOCK_K, BLOCK_N]
b_vals_transposed = tl.trans(b_vals)
# Compute matrix multiplication: a_vals × b_vals_transposed
product = tl.dot(a_vals, b_vals_transposed)
acc += product
# Store result to output tensor c
c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[
None, :] * stride_cn
tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])
def linear_persistent(x, y):
"""
Implement matrix multiplication with Triton: x @ y^T
Uses a fixed-size 1D grid
Parameters:
x: torch.Tensor, shape [M, K]
y: torch.Tensor, shape [N, K]
Returns:
output: torch.Tensor, shape [M, N]
"""
# Validate input shapes
assert x.dim() == 2, "x must be a 2D tensor"
assert y.dim() == 2, "y must be a 2D tensor"
assert x.shape[1] == y.shape[
1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
M, K = x.shape
N, _ = y.shape
# Allocate output tensor (same data type as x)
output = torch.zeros((M, N), dtype=x.dtype, device=x.device)
# Define block sizes (can be adjusted based on hardware)
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128
# Calculate number of blocks per dimension (ceil division)
num_blocks_m = triton.cdiv(M, BLOCK_M)
num_blocks_n = triton.cdiv(N, BLOCK_N)
# Set fixed 1D grid size
grid_size = driver.active.utils.get_device_properties(
torch.npu.current_device())["num_vectorcore"] // 2
grid = (grid_size, )
# Launch kernel
linear_persistent_kernel[grid](
a_ptr=x,
b_ptr=y,
c_ptr=output,
M=M,
N=N,
K=K,
stride_am=x.stride(0),
stride_ak=x.stride(1),
stride_bn=y.stride(0),
stride_bk=y.stride(1),
stride_cm=output.stride(0),
stride_cn=output.stride(1),
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
NUM_BLOCKS_M=num_blocks_m, # Number of blocks in M dimension
NUM_BLOCKS_N=num_blocks_n, # Number of blocks in N dimension
GRID_SIZE=grid_size, # Fixed grid size
)
return output
def mm_batch_invariant(a, b):
return matmul_persistent(a, b)
def bmm_batch_invariant(a, b, *, out=None):
# Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
# Process each batch separately with our persistent kernel
if a.ndim == 3 and b.ndim == 3:
results = []
for i in range(a.shape[0]):
results.append(matmul_persistent(a[i], b[i]))
result = torch.stack(results, dim=0)
if out is not None:
out.copy_(result)
return out
return result
else:
raise ValueError(f"bmm_batch_invariant expects 3D tensors, "
f"got shapes {a.shape} and {b.shape}")
def addmm_batch_invariant(bias, a, b):
return matmul_persistent(a, b, bias=bias)
def matmul_batch_invariant(a, b, *, out=None):
    # torch.matmul dispatches on input dimensionality.
    # For 2D x 2D it is equivalent to torch.mm.
if a.ndim == 2 and b.ndim == 2:
result = matmul_persistent(a, b)
if out is not None:
out.copy_(result)
return out
return result
elif a.ndim == 3 and b.ndim == 3:
# Handle batched case like bmm
return bmm_batch_invariant(a, b, out=out)
elif a.ndim == 3 and b.ndim == 2:
# Handle 3D x 2D: common for linear layers
# (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
# Reshape to 2D, do mm, reshape back
batch, seq, hidden = a.shape
a_2d = a.reshape(-1, hidden)
result_2d = matmul_persistent(a_2d, b)
result = result_2d.reshape(batch, seq, -1)
if out is not None:
out.copy_(result)
return out
return result
elif a.ndim == 2 and b.ndim == 3:
# Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
# By broadcasting `a` to 3D, we can reuse the batched matrix
# multiplication logic.
a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
return bmm_batch_invariant(a_expanded, b, out=out)
elif a.ndim == 4 and b.ndim == 4:
# Handle 4D attention tensors: [batch, heads, seq, dim]
# Reshape to 3D, process, reshape back
batch, heads, seq_a, dim_a = a.shape
_, _, dim_b, seq_b = b.shape
# Reshape to [batch*heads, seq_a, dim_a]
a_3d = a.reshape(batch * heads, seq_a, dim_a)
b_3d = b.reshape(batch * heads, dim_b, seq_b)
# Do batched matmul
result_3d = bmm_batch_invariant(a_3d, b_3d)
# Reshape back to [batch, heads, seq_a, seq_b]
result = result_3d.reshape(batch, heads, seq_a, seq_b)
if out is not None:
out.copy_(result)
return out
return result
else:
raise ValueError(
f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
f"3D x 2D, 2D x 3D, and 4D x 4D, "
f"got shapes {a.shape} and {b.shape}")
def linear_batch_invariant(input_, weight, bias=None):
output = linear_persistent(input_, weight)
if bias is not None:
output = output + bias
return output
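

if __name__ == "__main__":
    # Hedged self-check sketch (assumes an Ascend NPU with Triton available):
    # compares the batch-invariant kernels against eager PyTorch references.
    x = torch.randn(64, 128, device="npu", dtype=torch.float32)
    y = torch.randn(128, 256, device="npu", dtype=torch.float32)
    w = torch.randn(256, 128, device="npu", dtype=torch.float32)
    bias = torch.randn(256, device="npu", dtype=torch.float32)
    torch.testing.assert_close(mm_batch_invariant(x, y), x @ y,
                               rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(addmm_batch_invariant(bias, x, y),
                               torch.addmm(bias, x, y),
                               rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(linear_batch_invariant(x, w),
                               torch.nn.functional.linear(x, w),
                               rtol=1e-3, atol=1e-3)
    print("matmul batch-invariant kernels match eager references")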

View File

@@ -0,0 +1,177 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Optional

import torch
from vllm.triton_utils import tl, triton
@triton.jit
def mean_kernel(
input_ptr,
output_ptr,
input_stride0,
input_stride1,
input_stride2,
output_stride0,
output_stride1,
M, # size before reduction dim
N, # size of reduction dim
K, # size after reduction dim
BLOCK_SIZE: tl.constexpr,
):
"""
Kernel for computing mean along a single dimension.
Input is viewed as (M, N, K) where N is the dimension being reduced.
"""
# Program ID gives us which output element we're computing
pid = tl.program_id(0)
# Compute output indices
m_idx = pid // K
k_idx = pid % K
# Bounds check
if m_idx >= M or k_idx >= K:
return
# Accumulate sum across reduction dimension
acc = 0.0
for n_start in range(0, N, BLOCK_SIZE):
n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
mask = n_offsets < N
# Calculate input indices
input_idx = m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
# Load and accumulate
vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
acc += tl.sum(vals)
# Compute mean and store
mean_val = acc / N
output_idx = m_idx * output_stride0 + k_idx * output_stride1
tl.store(output_ptr + output_idx, mean_val)
def mean_dim(
    input_: torch.Tensor,
    dim: int,
    keepdim: bool = False,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""
Triton implementation of torch.mean with single dimension reduction.
Args:
input: Input tensor
dim: Single dimension along which to compute mean
keepdim: Whether to keep the reduced dimension
dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs)
Returns:
Tensor with mean values along specified dimension
"""
# Validate inputs
assert -input_.ndim <= dim < input_.ndim, (
f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions")
# Handle negative dim
if dim < 0:
dim = dim + input_.ndim
# Handle dtype
if dtype is None:
if input_.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
dtype = torch.float32
else:
dtype = input_.dtype
# Convert input to appropriate dtype if needed
if input_.dtype != dtype:
input_ = input_.to(dtype)
# Get input shape and strides
shape = list(input_.shape)
# Calculate dimensions for kernel
M = 1
for i in range(dim):
M *= shape[i]
N = shape[dim]
K = 1
for i in range(dim + 1, len(shape)):
K *= shape[i]
# Reshape input to 3D view (M, N, K)
input_3d = input_.reshape(M, N, K)
# Create output shape
if keepdim:
output_shape = shape.copy()
output_shape[dim] = 1
else:
output_shape = shape[:dim] + shape[dim + 1:]
# Create output tensor
output = torch.empty(output_shape, dtype=dtype, device=input_.device)
# Reshape output for kernel
if keepdim:
output_2d = output.reshape(M, 1, K).squeeze(1)
else:
output_2d = output.reshape(M, K)
# Launch kernel
grid = (M * K, )
BLOCK_SIZE = 1024
mean_kernel[grid](
input_3d,
output_2d,
input_3d.stride(0),
input_3d.stride(1),
input_3d.stride(2),
output_2d.stride(0),
output_2d.stride(1) if output_2d.ndim > 1 else 0,
M,
N,
K,
BLOCK_SIZE,
)
return output
def mean_batch_invariant(
    input_: torch.Tensor,
    dim: list[int],
    keepdim: bool = False,
    dtype: Optional[torch.dtype] = None,
):
    assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
    if len(dim) == 1:
        return mean_dim(input_, dim[0], keepdim=keepdim, dtype=dtype)
else:
assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32
}, ("only float types supported for now")
if len(dim) == 0:
dim = list(range(input_.ndim))
n_elems = 1
for d in dim:
n_elems *= input_.shape[d]
return torch.sum(input_, dim=dim, keepdim=keepdim,
dtype=torch.float32).to(dtype
or input_.dtype) / n_elems
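

if __name__ == "__main__":
    # Hedged parity sketch (assumes an Ascend NPU with Triton available):
    x = torch.randn(4, 8, 16, device="npu", dtype=torch.float32)
    torch.testing.assert_close(mean_batch_invariant(x, dim=[1]),
                               x.mean(dim=1), rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(mean_batch_invariant(x, dim=[0, 2]),
                               x.mean(dim=[0, 2]), rtol=1e-5, atol=1e-5)
    print("mean batch-invariant kernel matches torch.mean")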

View File

@@ -0,0 +1,153 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from triton.runtime import driver # type: ignore
from vllm.triton_utils import tl, triton
@triton.jit
def _rms_norm_kernel(
input_ptr,
weight_ptr,
output_ptr,
input_row_stride,
output_row_stride,
    n_rows,  # Total number of rows
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""
Compute RMS normalization along the last dimension of a 2D tensor.
RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
Each program handles multiple rows of the input tensor.
"""
pid = tl.program_id(0)
n_programs = tl.num_programs(0)
rows_per_program = (n_rows + n_programs - 1) // n_programs
start_row = pid * rows_per_program
end_row = tl.minimum(start_row + rows_per_program, n_rows)
for row_idx in range(start_row, end_row):
row_start_ptr = input_ptr + row_idx * input_row_stride
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# Step 1: Compute sum of squares in float32 to avoid overflow
sum_sq = tl.zeros([1], dtype=tl.float32)
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
vals_f32 = vals.to(tl.float32)
sq_vals = vals_f32 * vals_f32
sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
# Step 2: Compute RMS (root mean square) in float32
mean_sq = sum_sq / n_cols
rms = tl.sqrt(mean_sq + eps)
inv_rms = 1.0 / rms
# Step 3: Normalize and apply weight
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
vals_f32 = vals.to(tl.float32)
weight_f32 = weight.to(tl.float32)
output_f32 = vals_f32 * inv_rms * weight_f32
output = output_f32.to(vals.dtype)
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
def rms_norm(
input_: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Compute RMS normalization using Triton kernel with fixed grid size.
RMS Norm normalizes the input by the root mean square and scales by weight:
output = input / sqrt(mean(input^2) + eps) * weight
Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
Tensor with RMS normalization applied along the last dimension
"""
assert weight.dim() == 1, "Weight must be 1-dimensional"
assert input_.shape[-1] == weight.shape[0], (
f"Input last dimension ({input_.shape[-1]}) must match "
f"weight dimension ({weight.shape[0]})")
# Flatten all dimensions except the last one
original_shape = input_.shape
input_2d = input_.reshape(-1, input_.shape[-1])
input_2d = input_2d.contiguous()
weight = weight.contiguous()
n_rows, n_cols = input_2d.shape
output = torch.empty_like(input_2d, dtype=input_.dtype)
BLOCK_SIZE = 1024
max_grid_size = driver.active.utils.get_device_properties(
torch.npu.current_device())["num_vectorcore"]
grid = (min(n_rows, max_grid_size), )
_rms_norm_kernel[grid](
input_2d,
weight,
output,
input_2d.stride(0),
output.stride(0),
n_rows,
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
)
return output.reshape(original_shape)
def rms_norm_batch_invariant(
input_: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Batch-invariant wrapper for RMS normalization.
This function provides a deterministic, batch-invariant implementation
of RMS normalization for use with the batch_invariant mode.
Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
RMS normalized tensor
"""
return rms_norm(input_, weight, eps=eps)
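

def _rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                        eps: float = 1e-6) -> torch.Tensor:
    # Hedged eager reference for the formula in the docstrings above; useful
    # for spot-checking the Triton kernel (the reference itself runs on any
    # device, while the kernel assumes an Ascend NPU).
    x_f32 = x.float()
    inv_rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f32 * inv_rms * weight.float()).to(x.dtype)


if __name__ == "__main__":
    x = torch.randn(8, 1024, device="npu", dtype=torch.bfloat16)
    w = torch.randn(1024, device="npu", dtype=torch.bfloat16)
    torch.testing.assert_close(rms_norm(x, w), _rms_norm_reference(x, w),
                               rtol=1e-2, atol=1e-2)
    print("rms_norm kernel matches eager reference")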

View File

@@ -0,0 +1,29 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
def softmax_batch_invariant(input_, dim, dtype=None):
    # Compute softmax in a deterministic way, matching torch semantics:
    # an optional dtype casts the input before the operation.
    if dtype is not None:
        input_ = input_.to(dtype)
    # Subtract the max for numerical stability (standard practice).
    input_max = torch.amax(input_, dim=dim, keepdim=True)
    input_ = input_ - input_max
    exp_x = torch.exp(input_)
    sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
    return exp_x / sum_exp_x
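

if __name__ == "__main__":
    # Hedged parity sketch (any device): the decomposition above matches
    # torch.softmax up to floating-point rounding.
    x = torch.randn(2, 5)
    torch.testing.assert_close(softmax_batch_invariant(x, dim=-1),
                               torch.softmax(x, dim=-1),
                               rtol=1e-5, atol=1e-6)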

View File

@@ -50,6 +50,7 @@ from vllm.v1.worker.workspace import init_workspace_manager
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.batch_invariant import init_batch_invariance
from vllm_ascend.cpu_binding import bind_cpus
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
@@ -453,6 +454,7 @@ class NPUWorker(WorkerBase):
def _init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
init_batch_invariance()
init_distributed_environment(self.parallel_config.world_size,
self.rank, self.distributed_init_method,
self.local_rank, "hccl")