[Feature] implement basic framework for batch invariant (#5517)
### What this PR does / why we need it?
This PR implements the basic framework for batch invariance; see
https://github.com/vllm-project/vllm-ascend/issues/5487.
### Does this PR introduce _any_ user-facing change?
No. We reuse the `vllm_is_batch_invariant` function from vLLM to determine
whether batch-invariant mode is enabled.
- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1
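
For reference, enabling the new mode only requires the environment variable that vLLM already checks; a minimal sketch (the model choice mirrors the e2e tests in this PR):

```python
import os

# Must be set before the engine is created.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

# With batch-invariant mode on, this prompt should produce bitwise-identical
# logprobs whether it runs alone or mixed into a larger batch.
llm = LLM(model="Qwen/Qwen3-0.6B", enforce_eager=True)
outs = llm.generate(["The capital of France is"],
                    SamplingParams(temperature=0.0, max_tokens=8))
print(outs[0].outputs[0].text)
```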
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Signed-off-by: Lord_of_Ironhill <suiweiyi@huawei.com>
Signed-off-by: zjchenn <zjchenn@gmail.com>
Signed-off-by: wangx700 <wangxin700@huawei.com>
Co-authored-by: Lord_of_Ironhill <suiweiyi@huawei.com>
Co-authored-by: zjchenn <zjchenn@gmail.com>
Co-authored-by: wangx700 <wangxin700@huawei.com>
1 .github/workflows/_e2e_test.yaml (vendored)
@@ -123,6 +123,7 @@ jobs:
           pytest -sv --durations=0 tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
 
           pytest -sv --durations=0 tests/e2e/singlecard/model_runner_v2/test_basic.py
+          pytest -sv --durations=0 tests/e2e/singlecard/test_batch_invariant.py
 
   e2e-2-cards:
     name: multicard-2
672 tests/e2e/singlecard/test_batch_invariant.py (new file)
@@ -0,0 +1,672 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/tests/v1/determinism/test_batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import os
import random

import pytest
import torch
from vllm import LLM, SamplingParams

from tests.e2e.conftest import cleanup_dist_env_and_memory

DEFAULT_MODEL = "Qwen/Qwen3-0.6B"


@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
    """Automatically enable batch-invariant kernel overrides for all tests."""
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")


def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    # Generate more realistic prompts that will actually produce varied
    # tokens, using a mix of common English text patterns.

    prompt_templates = [
        # Question-answer style
        "Question: What is the capital of France?\nAnswer: The capital of France is",
        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
        # Story/narrative style
        "Once upon a time in a distant galaxy, there lived",
        "The old man walked slowly down the street, remembering",
        "In the year 2157, humanity finally discovered",
        # Technical/code style
        "To implement a binary search tree in Python, first we need to",
        "The algorithm works by iterating through the array and",
        "Here's how to optimize database queries using indexing:",
        # Factual/informative style
        "The Renaissance was a period in European history that",
        "Climate change is caused by several factors including",
        "The human brain contains approximately 86 billion neurons which",
        # Conversational style
        "I've been thinking about getting a new laptop because",
        "Yesterday I went to the store and bought",
        "My favorite thing about summer is definitely",
    ]

    # Pick a random template
    base_prompt = random.choice(prompt_templates)

    if max_words < min_words:
        max_words = min_words
    target_words = random.randint(min_words, max_words)

    if target_words > 50:
        # For longer prompts, repeat padding context
        padding_text = (
            " This is an interesting topic that deserves more explanation. " *
            (target_words // 50))
        base_prompt = base_prompt + padding_text

    return base_prompt


def _extract_step_logprobs(request_output):
    if getattr(request_output, "outputs", None):
        inner = request_output.outputs[0]
        if hasattr(inner, "logprobs") and inner.logprobs is not None:
            t = torch.tensor(
                [
                    inner.logprobs[i][tid].logprob
                    for i, tid in enumerate(inner.token_ids)
                ],
                dtype=torch.float32,
            )
            return t, inner.token_ids

    return None, None


@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
        monkeypatch: pytest.MonkeyPatch):
    """
    Ensures that the same request (the 'needle' prompt) yields identical output
    whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
    using the high-level v1 LLM() API only (no manual batching).

    Strategy:
    - Create one LLM engine with max_num_seqs=N.
    - Compute a baseline output by submitting the needle prompt alone.
    - For many trials, generate a batch (size up to N) where the needle
      appears at a random position among random filler prompts.
    - Track how many trials match vs mismatch, and report totals at the end.
    The test fails if any mismatches occur, but we still dump pass/fail
    counts.

    Notes:
    - Use seeded stochastic sampling with a fixed seed to test determinism.
    - Outputs are intentionally longer and sampled at higher temperature/top_p
      to produce a more random-sounding phrase, yet remain deterministic by
      seed.
    - Keep max_tokens and max_model_len bounded for speed and memory use.
    """
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)

    # Allow overrides from the environment (useful for CI tuning).
    model = DEFAULT_MODEL

    num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "144"))
    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
    assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."

    # Bound GPU memory usage to avoid startup allocation failures.
    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.95"))
    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))

    # Sampling parameters: longer outputs with a more random-sounding
    # continuation, but still deterministic due to the fixed seed.
    temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
    top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
    max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "35"))

    sampling = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        seed=20240919,
    )

    needle_prompt = "There once was a "

    llm = None
    try:
        # Single engine reused for both the baseline and the batched runs.
        llm = LLM(
            model=model,
            max_num_seqs=max_batch_size,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
            dtype="bfloat16",
            tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
            enable_prefix_caching=False,
            enforce_eager=True,
            distributed_executor_backend="mp",
            # Enable for MOE models
            # enable_expert_parallel=True,
        )

        # Baseline generation for the needle prompt alone.
        baseline_out = llm.generate([needle_prompt], sampling)
        assert len(baseline_out) == 1
        assert len(baseline_out[0].outputs) >= 1
        baseline_text = baseline_out[0].outputs[0].text

        mismatches = 0

        for trial in range(num_trials):
            # Create a batch of size up to `max_batch_size` and insert the
            # needle at a random index.
            prompts: list[str] = []
            batch_size = random.randint(max_batch_size // 2, max_batch_size)
            needle_pos = random.randint(0, batch_size - 1)
            for i in range(batch_size):
                if i == needle_pos:
                    prompts.append(needle_prompt)
                else:
                    prompts.append(
                        _random_prompt(min_random_prompt, max_random_prompt))

            # Generate the full batch with the same engine
            outputs = llm.generate(prompts, sampling)
            # Find the needle output by position
            needle_output = outputs[needle_pos]
            assert needle_output.prompt == needle_prompt
            assert len(needle_output.outputs) >= 1
            text = needle_output.outputs[0].text

            if text != baseline_text:
                print(
                    f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                mismatches += 1

        passes = num_trials - mismatches
        # Dump how many passed vs failed
        print(f"[determinism] total={num_trials}, passed={passes}, "
              f"failed={mismatches}, max_batch_size={max_batch_size}")

        if mismatches > 0:
            pytest.fail(
                f"Nondeterministic outputs detected: {mismatches} failed out "
                f"of {num_trials} trials (max_batch_size={max_batch_size}).")

    finally:
        del llm
        cleanup_dist_env_and_memory()


def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
        monkeypatch: pytest.MonkeyPatch):
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
    model_name = DEFAULT_MODEL
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    # For batch invariance, disable custom all-reduce to ensure deterministic
    # all-reduce operations (custom all-reduce may not be deterministic).
    from vllm_ascend.batch_invariant import vllm_is_batch_invariant

    disable_custom_ar = vllm_is_batch_invariant()

    if disable_custom_ar:
        print(f"\n{'=' * 80}")
        print(
            f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})"
        )
        print(f"{'=' * 80}\n")

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        enable_prefix_caching=False,
        max_num_seqs=32,
        max_model_len=8192,
        dtype="bfloat16",
        gpu_memory_utilization=0.9,
        enforce_eager=True,
        distributed_executor_backend="mp",
    )

    # Use more realistic prompts for better token generation
    prompts = [_random_prompt(10, 50) for _ in range(32)]

    sp = SamplingParams(
        temperature=0.6,
        top_p=1.0,
        max_tokens=8,
        seed=1234,
        logprobs=5,
    )

    # BS=1: run prompts individually and collect logprobs per step.
    print("\n" + "=" * 80)
    print("STARTING BS=1 RUNS (each prompt individually)")
    print("=" * 80 + "\n")

    bs1_logprobs_per_prompt = []
    bs1_tokens_per_prompt = []
    for idx, p in enumerate(prompts):
        print(
            f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
        )
        outs = llm.generate([p], sp, use_tqdm=False)
        assert len(outs) == 1
        step_logprobs, token_ids = _extract_step_logprobs(outs[0])
        if step_logprobs is None:
            pytest.skip("Logprobs are not available on RequestOutput; "
                        "enable logprobs return to run this test.")
        bs1_logprobs_per_prompt.append(step_logprobs)
        bs1_tokens_per_prompt.append(token_ids)
        print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")

    # BS=N: run prompts in a batch and collect logprobs per step for each
    # prompt.
    print("\n" + "=" * 80)
    print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
    print("=" * 80 + "\n")

    outs_batched = llm.generate(prompts, sp, use_tqdm=False)
    assert len(outs_batched) == len(prompts)
    bsN_logprobs_per_prompt = []
    bsN_tokens_per_prompt = []

    print(f"\n[BS={len(prompts)}] Processing batched outputs...")
    for idx, o in enumerate(outs_batched):
        tokens = o.outputs[0].token_ids if o.outputs else "N/A"
        print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
        step_logprobs, token_ids = _extract_step_logprobs(o)
        if step_logprobs is None:
            pytest.skip("Logprobs are not available on RequestOutput; "
                        "enable logprobs return to run this test.")
        bsN_logprobs_per_prompt.append(step_logprobs)
        bsN_tokens_per_prompt.append(token_ids)

    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
    failed_prompts = []
    for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
            zip(
                bs1_logprobs_per_prompt,
                bsN_logprobs_per_prompt,
                bs1_tokens_per_prompt,
                bsN_tokens_per_prompt,
            )):
        if len(logprobs_bs1) != len(logprobs_bsN):
            reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
                      f"vs {len(logprobs_bsN)} (BS=N)")
            failed_prompts.append({
                "prompt_idx": i,
                "step": "all",
                "reason": reason,
                "prompt_preview": prompts[i][:100],
                "bs1_tokens": tokens_bs1,
                "bsN_tokens": tokens_bsN,
            })
            continue

        # Check whether the sampled tokens match first
        if tokens_bs1 != tokens_bsN:
            failed_prompts.append({
                "prompt_idx": i,
                "step": "sampling",
                "reason": "Different tokens sampled",
                "prompt_preview": prompts[i][:100],
                "bs1_tokens": tokens_bs1,
                "bsN_tokens": tokens_bsN,
                "bs1_all_logprobs":
                [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
                "bsN_all_logprobs":
                [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
            })
            continue

        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
            if a.shape != b.shape:
                failed_prompts.append({
                    "prompt_idx": i,
                    "step": t,
                    "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
                    "prompt_preview": prompts[i][:100],
                    "bs1_tokens": tokens_bs1,
                    "bsN_tokens": tokens_bsN,
                })
                break

            if not torch.equal(a, b):
                max_diff = torch.abs(a - b).max().item()
                # Print which token diverged
                print(
                    f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}"
                )
                bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
                bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
                print(f"  Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
                print(f"  BS=1 logprob: {a.tolist()}")
                print(f"  BS=N logprob: {b.tolist()}")
                failed_prompts.append({
                    "prompt_idx": i,
                    "step": t,
                    "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
                    "prompt_preview": prompts[i][:100],
                    "bs1_tokens": tokens_bs1,
                    "bsN_tokens": tokens_bsN,
                    "bs1_all_logprobs": [
                        logprobs_bs1[s].tolist()
                        for s in range(len(logprobs_bs1))
                    ],
                    "bsN_all_logprobs": [
                        logprobs_bsN[s].tolist()
                        for s in range(len(logprobs_bsN))
                    ],
                })
                break
    del llm
    cleanup_dist_env_and_memory()
    # Print a summary of all failures
    if failed_prompts:
        print(f"\n{'=' * 80}")
        fail_msg = (f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/"
                    f"{len(prompts)} prompts failed")
        print(fail_msg)
        print(f"{'=' * 80}")
        for fail in failed_prompts:
            print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):")
            print(f"  Reason: {fail['reason']}")
            print(f"  Preview: {fail['prompt_preview']}...")

            # Always show the tokens
            if "bs1_tokens" in fail:
                print(f"  BS=1 tokens: {fail['bs1_tokens']}")
            if "bsN_tokens" in fail:
                print(f"  BS=N tokens: {fail['bsN_tokens']}")

            if "bs1_all_logprobs" in fail:
                print(
                    f"  BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:"
                )
                for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]):
                    print(f"    Step {step_idx}: {logprobs}")
                print(
                    f"  BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:"
                )
                for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]):
                    print(f"    Step {step_idx}: {logprobs}")
        print(f"{'=' * 80}\n")

        # Fail the test with a summary
        msg = (f"Batch invariance violated in {len(failed_prompts)}/"
               f"{len(prompts)} prompts. See output above for details.")
        pytest.fail(msg)


def test_simple_generation(monkeypatch: pytest.MonkeyPatch):
    """
    Simple test that runs the model with a basic prompt and prints the output.
    Useful for quick smoke testing and debugging.
    """
    model = DEFAULT_MODEL

    llm = LLM(
        model=model,
        max_num_seqs=1,
        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
        gpu_memory_utilization=0.9,
        max_model_len=2048,
        dtype="float16",
        enable_prefix_caching=False,
        enforce_eager=True,
        distributed_executor_backend="mp",
    )

    prompt = "The capital of France is"
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=20,
    )

    print(f"\n{'=' * 80}")
    print("Running simple generation test")
    print(f"Prompt: '{prompt}'")
    print(f"{'=' * 80}\n")

    try:
        outputs = llm.generate([prompt], sampling_params)

        assert len(outputs) == 1
        output_text = outputs[0].outputs[0].text

        print(f"Output: '{output_text}'")
        print(f"\n{'=' * 80}")
        print(f"Full completion: '{prompt}{output_text}'")
        print(f"{'=' * 80}\n")

    finally:
        del llm
        cleanup_dist_env_and_memory()


def test_logprobs_without_batch_invariance_should_fail(
        monkeypatch: pytest.MonkeyPatch):
    """
    This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
    It DISABLES batch invariance mode and expects to see non-deterministic
    behavior between BS=1 and BS=N runs. This demonstrates that batch
    invariance is actually doing something useful.

    The test will PASS if we detect differences (proving batch invariance matters).
    The test will FAIL if everything matches (suggesting batch invariance isn't needed).
    """
    # CRITICAL: disable batch invariance for this test
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
    model_name = DEFAULT_MODEL
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    print(f"\n{'=' * 80}")
    print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior")
    print(f"{'=' * 80}\n")

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        enable_prefix_caching=False,
        max_num_seqs=32,
        max_model_len=8192,
        dtype="bfloat16",
        enforce_eager=True,
        distributed_executor_backend="mp",
    )

    # Build ragged prompts to change shapes significantly across BS=1 vs BS=N
    long_min = int(os.getenv("VLLM_MIN_PROMPT", "768"))
    long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
    prompts: list[str] = []
    options = [
        (max(long_min, 1536), max(long_max, 3072)),  # very long
        (max(1024, long_min), max(2048, long_max)),  # long
        (256, 512),  # mid
        (10, 20),  # short
    ]

    for _ in range(32):
        lo, hi = random.choice(options)
        prompts.append(_random_prompt(lo, hi))

    sp = SamplingParams(
        temperature=0.6,
        top_p=1.0,
        max_tokens=8,
        seed=1234,
        logprobs=5,
    )

    # BS=1: run prompts individually and collect logprobs per step.
    print("\n" + "=" * 80)
    print("STARTING BS=1 RUNS (each prompt individually)")
    print("=" * 80 + "\n")

    bs1_logprobs_per_prompt = []
    bs1_tokens_per_prompt = []
    for idx, p in enumerate(prompts):
        print(
            f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
        )
        outs = llm.generate([p], sp, use_tqdm=False)
        assert len(outs) == 1
        step_logprobs, token_ids = _extract_step_logprobs(outs[0])
        if step_logprobs is None:
            pytest.skip("Logprobs are not available on RequestOutput; "
                        "enable logprobs return to run this test.")
        bs1_logprobs_per_prompt.append(step_logprobs)
        bs1_tokens_per_prompt.append(token_ids)
        print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")

    # BS=N: run prompts in a batch and collect logprobs per step for each prompt.
    print("\n" + "=" * 80)
    print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
    print("=" * 80 + "\n")

    outs_batched = llm.generate(prompts, sp, use_tqdm=False)
    assert len(outs_batched) == len(prompts)
    bsN_logprobs_per_prompt = []
    bsN_tokens_per_prompt = []

    print(f"\n[BS={len(prompts)}] Processing batched outputs...")
    for idx, o in enumerate(outs_batched):
        tokens = o.outputs[0].token_ids if o.outputs else "N/A"
        print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
        step_logprobs, token_ids = _extract_step_logprobs(o)
        if step_logprobs is None:
            pytest.skip("Logprobs are not available on RequestOutput; "
                        "enable logprobs return to run this test.")
        bsN_logprobs_per_prompt.append(step_logprobs)
        bsN_tokens_per_prompt.append(token_ids)

    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
    differences_found = []
    for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
            zip(
                bs1_logprobs_per_prompt,
                bsN_logprobs_per_prompt,
                bs1_tokens_per_prompt,
                bsN_tokens_per_prompt,
            )):
        if len(logprobs_bs1) != len(logprobs_bsN):
            reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
                      f"vs {len(logprobs_bsN)} (BS=N)")
            differences_found.append({
                "prompt_idx": i,
                "step": "all",
                "reason": reason,
                "prompt_preview": prompts[i][:100],
                "bs1_tokens": tokens_bs1,
                "bsN_tokens": tokens_bsN,
            })
            continue

        # Check whether the sampled tokens match first
        if tokens_bs1 != tokens_bsN:
            differences_found.append({
                "prompt_idx": i,
                "step": "sampling",
                "reason": "Different tokens sampled",
                "prompt_preview": prompts[i][:100],
                "bs1_tokens": tokens_bs1,
                "bsN_tokens": tokens_bsN,
            })
            continue

        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
            if a.shape != b.shape:
                differences_found.append({
                    "prompt_idx": i,
                    "step": t,
                    "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
                    "prompt_preview": prompts[i][:100],
                    "bs1_tokens": tokens_bs1,
                    "bsN_tokens": tokens_bsN,
                })
                break

            if not torch.equal(a, b):
                max_diff = torch.abs(a - b).max().item()
                print(f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
                      f"Token {t}: max_diff={max_diff:.6e}")
                bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
                bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
                print(f"  Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
                print(f"  BS=1 logprob: {a.tolist()}")
                print(f"  BS=N logprob: {b.tolist()}")
                differences_found.append({
                    "prompt_idx": i,
                    "step": t,
                    "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
                    "prompt_preview": prompts[i][:100],
                    "bs1_tokens": tokens_bs1,
                    "bsN_tokens": tokens_bsN,
                })
                break
    del llm
    cleanup_dist_env_and_memory()
    # Print summary
    print(f"\n{'=' * 80}")
    if differences_found:
        success_msg = (
            f"✓ SUCCESS: Batch invariance is doing something! "
            f"Found {len(differences_found)}/{len(prompts)} prompts "
            f"with differences when batch invariance was DISABLED.")
        print(success_msg)
        print(f"{'=' * 80}")
        for diff in differences_found:
            print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):")
            print(f"  Reason: {diff['reason']}")
            print(f"  Preview: {diff['prompt_preview']}...")
            if "bs1_tokens" in diff:
                print(f"  BS=1 tokens: {diff['bs1_tokens']}")
            if "bsN_tokens" in diff:
                print(f"  BS=N tokens: {diff['bsN_tokens']}")
        print(f"{'=' * 80}\n")
        # Test PASSES because we found differences (batch invariance matters!)
        return
    else:
        # Test FAILS because everything matched even without batch invariance
        fail_msg = (
            f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
            f"between BS=1 and BS=N even with batch invariance DISABLED. "
            f"This suggests batch invariance might not be necessary, "
            f"or the test needs more sensitive prompts.")
        print(fail_msg)
        print(f"{'=' * 80}\n")
        pytest.fail(fail_msg)
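For local debugging, the suite can also be driven through `pytest.main`; a sketch mirroring the CI invocation added to `_e2e_test.yaml` above:

```python
import pytest

# Same arguments as the CI step; the autouse fixture in the test file
# sets VLLM_BATCH_INVARIANT=1 for every test.
raise SystemExit(
    pytest.main(["-sv", "--durations=0",
                 "tests/e2e/singlecard/test_batch_invariant.py"]))
```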
82 vllm_ascend/batch_invariant.py (new file)
@@ -0,0 +1,82 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import os

import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.triton_utils import HAS_TRITON

logger = init_logger(__name__)

if HAS_TRITON:
    from vllm_ascend.ops.triton.batch_invariant.matmul import (
        addmm_batch_invariant, bmm_batch_invariant, linear_batch_invariant,
        matmul_batch_invariant, mm_batch_invariant)


def override_envs_for_invariance():
    # TODO(Ronald): set the attention backend to deterministic mode.

    # Enabling NZ mode feeds NZ-format input to the Triton operators,
    # which causes accuracy anomalies, so turn it off.
    os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"

    # Communication determinism settings
    os.environ["HCCL_DETERMINISTIC"] = "true"
    os.environ["LCCL_DETERMINISTIC"] = "1"


_batch_invariant_LIB = None


def enable_batch_invariant_mode():
    global _batch_invariant_LIB

    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")

    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")


def init_batch_invariance():
    """
    Initialize batch-invariant mode for vLLM on Ascend NPU.

    This function:
    1. Sets environment variables for deterministic computation
    2. Registers batch-invariant implementations for torch operators
    3. Will enable a batch-invariant attention backend (still a TODO)

    Call this function early in your application, or set the
    VLLM_BATCH_INVARIANT=1 environment variable to enable it automatically.
    """
    if vllm_is_batch_invariant():
        if HAS_TRITON:
            logger.info(
                "Enabling batch-invariant mode for vLLM on Ascend NPU.")
            override_envs_for_invariance()
            enable_batch_invariant_mode()
        else:
            logger.warning(
                "Batch-invariant mode requested but Triton is not available; "
                "skipping batch-invariant initialization.")
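Because the overrides are registered on the `aten` library for the `NPU` dispatch key, plain `torch` calls pick them up transparently once initialization has run. A minimal sketch of that flow (assumes an Ascend environment with `torch_npu` and Triton available):

```python
import os

os.environ["VLLM_BATCH_INVARIANT"] = "1"  # must be set before initialization

import torch
import torch_npu  # noqa: F401  (Ascend PyTorch adapter)

from vllm_ascend.batch_invariant import init_batch_invariance

# Registers the aten::mm/addmm/matmul/linear/bmm overrides for the NPU key.
init_batch_invariance()

a = torch.randn(16, 32, device="npu", dtype=torch.float16)
b = torch.randn(32, 8, device="npu", dtype=torch.float16)
# Dispatches to mm_batch_invariant via the aten::mm override above.
c = torch.mm(a, b)
```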
0 vllm_ascend/ops/triton/batch_invariant/__init__.py (new file)
403 vllm_ascend/ops/triton/batch_invariant/matmul.py (new file)
@@ -0,0 +1,403 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
from triton.runtime import driver  # type: ignore
from vllm.triton_utils import tl, triton


@triton.jit
def matmul_bias_persistent_kernel(
    # Input tensor pointers
    x_ptr,
    y_ptr,
    bias_ptr,
    output_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # Stride information
    stride_xm,
    stride_xk,  # Strides of x: [M, K]
    stride_yk,
    stride_yn,  # Strides of y: [K, N]
    stride_bias,  # Stride of bias: [N]
    stride_outm,
    stride_outn,  # Strides of output: [M, N]
    # Whether to use bias
    has_bias: tl.constexpr,
    # Block sizes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)  # Row block ID
    pid_n = tl.program_id(1)  # Column block ID

    # Calculate the starting position of the current block in the matrix
    rm_start = pid_m * BLOCK_M
    rn_start = pid_n * BLOCK_N

    # Create index ranges
    rm = rm_start + tl.arange(0, BLOCK_M)  # Row index range [BLOCK_M]
    rn = rn_start + tl.arange(0, BLOCK_N)  # Column index range [BLOCK_N]
    rk = tl.arange(0, BLOCK_K)  # K dimension index range [BLOCK_K]

    # Initialize accumulator to 0. The fixed BLOCK_K loop below gives every
    # output tile the same float32 accumulation order regardless of batch
    # size, which is what makes this kernel batch-invariant.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Loop over the K dimension, processing BLOCK_K elements per iteration
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        k_start = k * BLOCK_K
        # Calculate pointer offsets for x (row-major)
        x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] +
                                                    k_start) * stride_xk
        # Calculate pointer offsets for y (row-major)
        y_ptrs = y_ptr + (rk[:, None] +
                          k_start) * stride_yk + rn[None, :] * stride_yn

        # Create masks to prevent out-of-bounds access
        x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
        y_mask = ((rk[:, None] + k_start) < K) & (rn[None, :] < N)

        # Load data chunks from global memory
        x_chunk = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        y_chunk = tl.load(y_ptrs, mask=y_mask, other=0.0).to(tl.float32)

        # Compute matrix multiplication accumulation
        acc += tl.dot(x_chunk, y_chunk, allow_tf32=False)

    # Add bias if the has_bias flag is set
    if has_bias:
        # Load bias values (broadcast to all rows)
        bias_ptrs = bias_ptr + rn * stride_bias
        bias_mask = rn < N
        bias_vals = tl.load(bias_ptrs, mask=bias_mask,
                            other=0.0).to(tl.float32)
        # Add bias to accumulator (automatic broadcasting)
        acc += bias_vals[None, :]

    # Calculate output pointer positions
    out_ptrs = output_ptr + rm[:,
                               None] * stride_outm + rn[None, :] * stride_outn
    out_mask = (rm[:, None] < M) & (rn[None, :] < N)

    # Store result to global memory
    tl.store(out_ptrs, acc.to(out_ptrs.dtype.element_ty), mask=out_mask)


def matmul_persistent(x, y, bias=None):
    """
    Matrix multiplication with optional bias using Triton:
    x @ y + bias (if bias is not None).

    Parameters:
        x: torch.Tensor, shape [M, K]
        y: torch.Tensor, shape [K, N]
        bias: torch.Tensor, shape [N] or None

    Returns:
        output: torch.Tensor, shape [M, N]
    """
    # Validate input shapes
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[0], (
        f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, "
        f"y.shape[0]={y.shape[0]}")

    M, K = x.shape
    _, N = y.shape
    # Validate bias shape (if not None)
    if bias is not None:
        assert bias.dim() == 1, "bias must be a 1D tensor"
        assert y.shape[1] == bias.shape[0], (
            f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, "
            f"bias.shape[0]={bias.shape[0]}")

    # Allocate output tensor (same data type as x)
    output = torch.empty((M, N), dtype=x.dtype, device=x.device)

    # Define block sizes (can be adjusted based on hardware)
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128

    # Calculate grid size (one program per output block)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    # Handle the case when bias is None
    if bias is None:
        # Create a dummy bias tensor (unused since has_bias=False)
        dummy_bias = torch.empty(0, dtype=x.dtype, device=x.device)
        has_bias = False
        bias_stride = 0
        bias_to_pass = dummy_bias
    else:
        has_bias = True
        bias_stride = bias.stride(0)
        bias_to_pass = bias
    # Launch kernel
    matmul_bias_persistent_kernel[grid](
        x,
        y,
        bias_to_pass,
        output,  # Input/output tensors
        M,
        N,
        K,  # Matrix dimensions
        x.stride(0),
        x.stride(1),  # Strides of x
        y.stride(0),
        y.stride(1),  # Strides of y
        bias_stride,  # Stride of bias (0 if bias is None)
        output.stride(0),
        output.stride(1),  # Strides of output
        has_bias,  # Flag indicating whether to use bias
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )

    return output


@triton.jit
def linear_persistent_kernel(
    a_ptr,  # Pointer to tensor a, shape [M, K]
    b_ptr,  # Pointer to tensor b, shape [N, K]
    c_ptr,  # Pointer to output tensor c, shape [M, N]
    M,  # Number of rows in tensor a
    N,  # Number of rows in tensor b (number of columns in output c)
    K,  # Number of columns in both tensor a and tensor b
    stride_am,  # Stride of tensor a along dimension M (typically K)
    stride_ak,  # Stride of tensor a along dimension K (typically 1)
    stride_bn,  # Stride of tensor b along dimension N (typically K)
    stride_bk,  # Stride of tensor b along dimension K (typically 1)
    stride_cm,  # Stride of tensor c along dimension M (typically N)
    stride_cn,  # Stride of tensor c along dimension N (typically 1)
    BLOCK_M: tl.constexpr,  # Block size for M dimension
    BLOCK_N: tl.constexpr,  # Block size for N dimension
    BLOCK_K: tl.constexpr,  # Block size for K dimension
    NUM_BLOCKS_M: tl.constexpr,  # Number of blocks in M dimension
    NUM_BLOCKS_N: tl.constexpr,  # Number of blocks in N dimension
    GRID_SIZE: tl.constexpr,  # Fixed 1D grid size
):
    # Get the current program's index in the fixed 1D grid
    pid = tl.program_id(0)
    total_blocks = NUM_BLOCKS_M * NUM_BLOCKS_N  # Total number of output blocks

    # Loop over the output blocks assigned to the current program
    for block_index in range(pid, total_blocks, GRID_SIZE):
        # Convert the 1D block index to 2D coordinates (m_block, n_block)
        m_block = block_index // NUM_BLOCKS_N
        n_block = block_index % NUM_BLOCKS_N

        # Calculate starting indices of the current output block
        start_m = m_block * BLOCK_M
        start_n = n_block * BLOCK_N

        # Create row and column index ranges within the current block
        m_indices = start_m + tl.arange(0, BLOCK_M)
        n_indices = start_n + tl.arange(0, BLOCK_N)

        # Create masks to handle boundaries
        m_mask = m_indices < M
        n_mask = n_indices < N

        # Initialize accumulator to 0
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        # Loop over the K dimension with step size BLOCK_K; the fixed
        # iteration order keeps the accumulation deterministic.
        for k_offset in range(0, K, BLOCK_K):
            k_indices = k_offset + tl.arange(0, BLOCK_K)
            k_mask = k_indices < K

            # Load a block of tensor a: shape [BLOCK_M, BLOCK_K]
            a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[
                None, :] * stride_ak
            a_vals = tl.load(a_ptrs,
                             mask=m_mask[:, None] & k_mask[None, :],
                             other=0.0)

            # Load a block of tensor b: shape [BLOCK_N, BLOCK_K]
            b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[
                None, :] * stride_bk
            b_vals = tl.load(b_ptrs,
                             mask=n_mask[:, None] & k_mask[None, :],
                             other=0.0)

            # Explicitly transpose the b block with tl.trans:
            # shape becomes [BLOCK_K, BLOCK_N]
            b_vals_transposed = tl.trans(b_vals)

            # Compute the matrix multiplication: a_vals x b_vals_transposed
            product = tl.dot(a_vals, b_vals_transposed)
            acc += product
        # Store the result to output tensor c
        c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[
            None, :] * stride_cn
        tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])


def linear_persistent(x, y):
    """
    Matrix multiplication with Triton: x @ y^T, using a fixed-size 1D grid.

    Parameters:
        x: torch.Tensor, shape [M, K]
        y: torch.Tensor, shape [N, K]

    Returns:
        output: torch.Tensor, shape [M, N]
    """
    # Validate input shapes
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[1], (
        f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, "
        f"y.shape[1]={y.shape[1]}")

    M, K = x.shape
    N, _ = y.shape

    # Allocate output tensor (same data type as x)
    output = torch.zeros((M, N), dtype=x.dtype, device=x.device)

    # Define block sizes (can be adjusted based on hardware)
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128

    # Calculate number of blocks per dimension (ceil division)
    num_blocks_m = triton.cdiv(M, BLOCK_M)
    num_blocks_n = triton.cdiv(N, BLOCK_N)

    # Set the fixed 1D grid size from the device's vector-core count
    grid_size = driver.active.utils.get_device_properties(
        torch.npu.current_device())["num_vectorcore"] // 2
    grid = (grid_size, )

    # Launch kernel
    linear_persistent_kernel[grid](
        a_ptr=x,
        b_ptr=y,
        c_ptr=output,
        M=M,
        N=N,
        K=K,
        stride_am=x.stride(0),
        stride_ak=x.stride(1),
        stride_bn=y.stride(0),
        stride_bk=y.stride(1),
        stride_cm=output.stride(0),
        stride_cn=output.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        NUM_BLOCKS_M=num_blocks_m,  # Number of blocks in M dimension
        NUM_BLOCKS_N=num_blocks_n,  # Number of blocks in N dimension
        GRID_SIZE=grid_size,  # Fixed grid size
    )

    return output


def mm_batch_invariant(a, b):
    return matmul_persistent(a, b)


def bmm_batch_invariant(a, b, *, out=None):
    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N).
    # Process each batch separately with our persistent kernel.
    if a.ndim == 3 and b.ndim == 3:
        results = []
        for i in range(a.shape[0]):
            results.append(matmul_persistent(a[i], b[i]))
        result = torch.stack(results, dim=0)

        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(f"bmm_batch_invariant expects 3D tensors, "
                         f"got shapes {a.shape} and {b.shape}")


def addmm_batch_invariant(bias, a, b):
    return matmul_persistent(a, b, bias=bias)


def matmul_batch_invariant(a, b, *, out=None):
    # torch.matmul can handle various dimensionalities.
    # For 2D x 2D it is equivalent to mm.
    if a.ndim == 2 and b.ndim == 2:
        result = matmul_persistent(a, b)
        if out is not None:
            out.copy_(result)
            return out
        return result
    elif a.ndim == 3 and b.ndim == 3:
        # Handle the batched case like bmm
        return bmm_batch_invariant(a, b, out=out)
    elif a.ndim == 3 and b.ndim == 2:
        # Handle 3D x 2D: common for linear layers
        # (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
        # Reshape to 2D, do mm, reshape back
        batch, seq, hidden = a.shape
        a_2d = a.reshape(-1, hidden)
        result_2d = matmul_persistent(a_2d, b)
        result = result_2d.reshape(batch, seq, -1)
        if out is not None:
            out.copy_(result)
            return out
        return result
    elif a.ndim == 2 and b.ndim == 3:
        # Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
        # By broadcasting `a` to 3D, we can reuse the batched matrix
        # multiplication logic.
        a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
        return bmm_batch_invariant(a_expanded, b, out=out)
    elif a.ndim == 4 and b.ndim == 4:
        # Handle 4D attention tensors: [batch, heads, seq, dim].
        # Reshape to 3D, process, reshape back.
        batch, heads, seq_a, dim_a = a.shape
        _, _, dim_b, seq_b = b.shape

        # Reshape to [batch*heads, seq_a, dim_a]
        a_3d = a.reshape(batch * heads, seq_a, dim_a)
        b_3d = b.reshape(batch * heads, dim_b, seq_b)

        # Do the batched matmul
        result_3d = bmm_batch_invariant(a_3d, b_3d)

        # Reshape back to [batch, heads, seq_a, seq_b]
        result = result_3d.reshape(batch, heads, seq_a, seq_b)

        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(
            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
            f"3D x 2D, 2D x 3D, and 4D x 4D, "
            f"got shapes {a.shape} and {b.shape}")


def linear_batch_invariant(input_, weight, bias=None):
    output = linear_persistent(input_, weight)

    if bias is not None:
        output = output + bias
    return output
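A quick way to sanity-check these kernels against the eager ops (a sketch; requires an Ascend NPU, and a small numerical difference from `torch.mm` is expected since the Triton kernel fixes its own accumulation order):

```python
import torch
import torch_npu  # noqa: F401  (Ascend PyTorch adapter)

from vllm_ascend.ops.triton.batch_invariant.matmul import mm_batch_invariant

a = torch.randn(64, 128, device="npu", dtype=torch.float16)
b = torch.randn(128, 256, device="npu", dtype=torch.float16)

out = mm_batch_invariant(a, b)
ref = torch.mm(a, b)
print(torch.max(torch.abs(out.float() - ref.float())))  # should be tiny
```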
177 vllm_ascend/ops/triton/batch_invariant/mean.py (new file)
@@ -0,0 +1,177 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from typing import Optional

import torch
from vllm.triton_utils import tl, triton


@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for computing mean along a single dimension.
    Input is viewed as (M, N, K) where N is the dimension being reduced.
    """
    # Program ID gives us which output element we're computing
    pid = tl.program_id(0)

    # Compute output indices
    m_idx = pid // K
    k_idx = pid % K

    # Bounds check
    if m_idx >= M or k_idx >= K:
        return
    # Accumulate the sum across the reduction dimension in a fixed order,
    # so the result does not depend on how the work is batched
    acc = 0.0
    for n_start in range(0, N, BLOCK_SIZE):
        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
        mask = n_offsets < N

        # Calculate input indices
        input_idx = (m_idx * input_stride0 + n_offsets * input_stride1 +
                     k_idx * input_stride2)
        # Load and accumulate
        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
        acc += tl.sum(vals)

    # Compute mean and store
    mean_val = acc / N
    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    tl.store(output_ptr + output_idx, mean_val)


def mean_dim(
    input_: torch.Tensor,
    dim: int,
    keepdim: bool = False,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Triton implementation of torch.mean with single-dimension reduction.

    Args:
        input_: Input tensor
        dim: Single dimension along which to compute the mean
        keepdim: Whether to keep the reduced dimension
        dtype: Output dtype. If None, uses the input dtype (or float32 for
            integer inputs)

    Returns:
        Tensor with mean values along the specified dimension
    """
    # Validate inputs
    assert -input_.ndim <= dim < input_.ndim, (
        f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions")

    # Handle negative dim
    if dim < 0:
        dim = dim + input_.ndim
    # Handle dtype
    if dtype is None:
        if input_.dtype in [
                torch.int8, torch.int16, torch.int32, torch.int64
        ]:
            dtype = torch.float32
        else:
            dtype = input_.dtype
    # Convert input to the target dtype if needed
    if input_.dtype != dtype:
        input_ = input_.to(dtype)

    # Get input shape
    shape = list(input_.shape)
    # Calculate dimensions for the kernel
    M = 1
    for i in range(dim):
        M *= shape[i]

    N = shape[dim]

    K = 1
    for i in range(dim + 1, len(shape)):
        K *= shape[i]

    # Reshape input to a 3D view (M, N, K)
    input_3d = input_.reshape(M, N, K)

    # Create output shape
    if keepdim:
        output_shape = shape.copy()
        output_shape[dim] = 1
    else:
        output_shape = shape[:dim] + shape[dim + 1:]

    # Create output tensor
    output = torch.empty(output_shape, dtype=dtype, device=input_.device)

    # Reshape output for the kernel
    if keepdim:
        output_2d = output.reshape(M, 1, K).squeeze(1)
    else:
        output_2d = output.reshape(M, K)

    # Launch kernel: one program per output element
    grid = (M * K, )
    BLOCK_SIZE = 1024

    mean_kernel[grid](
        input_3d,
        output_2d,
        input_3d.stride(0),
        input_3d.stride(1),
        input_3d.stride(2),
        output_2d.stride(0),
        output_2d.stride(1) if output_2d.ndim > 1 else 0,
        M,
        N,
        K,
        BLOCK_SIZE,
    )

    return output


def mean_batch_invariant(
    input_: torch.Tensor,
    dim: list[int],
    keepdim: bool = False,
    dtype: Optional[torch.dtype] = None,
):
    assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
    if len(dim) == 1:
        return mean_dim(input_, dim[0], keepdim=keepdim)
    else:
        assert input_.dtype in {
            torch.float16, torch.bfloat16, torch.float32
        }, "only float types supported for now"
        if len(dim) == 0:
            dim = list(range(input_.ndim))
        n_elems = 1
        for d in dim:
            n_elems *= input_.shape[d]
        return torch.sum(input_, dim=dim, keepdim=keepdim,
                         dtype=torch.float32).to(dtype
                                                 or input_.dtype) / n_elems
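To make the `(M, N, K)` view concrete: reducing a `(2, 3, 4)` tensor over `dim=1` gives `M=2`, `N=3`, `K=4`, so the kernel launches `M * K = 8` programs, one per output element. A small check against `torch.mean` (a sketch, same NPU/Triton assumptions as above):

```python
import torch
import torch_npu  # noqa: F401  (Ascend PyTorch adapter)

from vllm_ascend.ops.triton.batch_invariant.mean import mean_dim

x = torch.randn(2, 3, 4, device="npu", dtype=torch.float32)
out = mean_dim(x, dim=1)  # shape (2, 4); M=2, N=3, K=4
ref = torch.mean(x, dim=1)
print(torch.allclose(out, ref, atol=1e-6))
```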
153 vllm_ascend/ops/triton/batch_invariant/rmsnorm.py (new file)
@@ -0,0 +1,153 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
from triton.runtime import driver  # type: ignore
from vllm.triton_utils import tl, triton


@triton.jit
def _rms_norm_kernel(
    input_ptr,
    weight_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_rows,  # total number of rows
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute RMS normalization along the last dimension of a 2D tensor.
    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
    Each program handles multiple rows of the input tensor.
    """
    pid = tl.program_id(0)
    n_programs = tl.num_programs(0)

    rows_per_program = (n_rows + n_programs - 1) // n_programs
    start_row = pid * rows_per_program
    end_row = tl.minimum(start_row + rows_per_program, n_rows)

    for row_idx in range(start_row, end_row):
        row_start_ptr = input_ptr + row_idx * input_row_stride
        output_row_start_ptr = output_ptr + row_idx * output_row_stride

        # Step 1: Compute the sum of squares in float32 to avoid overflow
        sum_sq = tl.zeros([1], dtype=tl.float32)
        for col_offset in range(0, n_cols, BLOCK_SIZE):
            col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
            mask = col_idx < n_cols

            vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
            vals_f32 = vals.to(tl.float32)
            sq_vals = vals_f32 * vals_f32
            sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))

        # Step 2: Compute RMS (root mean square) in float32
        mean_sq = sum_sq / n_cols
        rms = tl.sqrt(mean_sq + eps)
        inv_rms = 1.0 / rms

        # Step 3: Normalize and apply the weight
        for col_offset in range(0, n_cols, BLOCK_SIZE):
            col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
            mask = col_idx < n_cols
            vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
            weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
            vals_f32 = vals.to(tl.float32)
            weight_f32 = weight.to(tl.float32)
            output_f32 = vals_f32 * inv_rms * weight_f32
            output = output_f32.to(vals.dtype)
            tl.store(output_row_start_ptr + col_idx, output, mask=mask)


def rms_norm(
    input_: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Compute RMS normalization using a Triton kernel with a fixed grid size.

    RMS Norm normalizes the input by the root mean square and scales by weight:
        output = input / sqrt(mean(input^2) + eps) * weight

    Args:
        input_: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        Tensor with RMS normalization applied along the last dimension
    """
    assert weight.dim() == 1, "Weight must be 1-dimensional"
    assert input_.shape[-1] == weight.shape[0], (
        f"Input last dimension ({input_.shape[-1]}) must match "
        f"weight dimension ({weight.shape[0]})")

    # Flatten all dimensions except the last one
    original_shape = input_.shape
    input_2d = input_.reshape(-1, input_.shape[-1])
    input_2d = input_2d.contiguous()
    weight = weight.contiguous()

    n_rows, n_cols = input_2d.shape

    output = torch.empty_like(input_2d, dtype=input_.dtype)
    BLOCK_SIZE = 1024
    max_grid_size = driver.active.utils.get_device_properties(
        torch.npu.current_device())["num_vectorcore"]

    grid = (min(n_rows, max_grid_size), )

    _rms_norm_kernel[grid](
        input_2d,
        weight,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_rows,
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output.reshape(original_shape)


def rms_norm_batch_invariant(
    input_: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Batch-invariant wrapper for RMS normalization.

    This function provides a deterministic, batch-invariant implementation
    of RMS normalization for use with the batch_invariant mode.

    Args:
        input_: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        RMS-normalized tensor
    """
    return rms_norm(input_, weight, eps=eps)
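A reference check against the plain-PyTorch definition `y = x / sqrt(mean(x^2) + eps) * w` (a sketch, same NPU/Triton assumptions as above):

```python
import torch
import torch_npu  # noqa: F401  (Ascend PyTorch adapter)

from vllm_ascend.ops.triton.batch_invariant.rmsnorm import rms_norm

x = torch.randn(8, 4096, device="npu", dtype=torch.float16)
w = torch.randn(4096, device="npu", dtype=torch.float16)

# Plain-PyTorch RMS norm, computed in float32 like the kernel does.
ref = (x.float() *
       torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)).to(
           x.dtype) * w
out = rms_norm(x, w, eps=1e-6)
print(torch.max(torch.abs(out.float() - ref.float())))  # should be tiny
```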
29 vllm_ascend/ops/triton/batch_invariant/softmax.py (new file)
@@ -0,0 +1,29 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch


def softmax_batch_invariant(input_, dim, dtype=None):
    # Compute softmax in a deterministic way. `dtype` matches the aten
    # signature; when given, cast first, as torch.softmax does.
    if dtype is not None:
        input_ = input_.to(dtype)
    # Subtract the max for numerical stability (standard practice)
    input_max = torch.amax(input_, dim=dim, keepdim=True)
    input_ = input_ - input_max
    exp_x = torch.exp(input_)
    sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
    return exp_x / sum_exp_x
@@ -50,6 +50,7 @@ from vllm.v1.worker.workspace import init_workspace_manager
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
+from vllm_ascend.batch_invariant import init_batch_invariance
 from vllm_ascend.cpu_binding import bind_cpus
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
@@ -453,6 +454,7 @@ class NPUWorker(WorkerBase):
 
     def _init_worker_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
+        init_batch_invariance()
         init_distributed_environment(self.parallel_config.world_size,
                                      self.rank, self.distributed_init_method,
                                      self.local_rank, "hccl")