[main] remove dbo code (#3712)

### What this PR does / why we need it?
Remove the DBO (dual-batch overlap) code. vLLM now supports DBO upstream via
https://github.com/vllm-project/vllm/pull/23693, so the Ascend-specific
implementation is no longer needed.
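
For users who enabled the feature through the removed vllm-ascend switch, the sketch below contrasts the old knob with where DBO now lives. The upstream enablement mechanism is deliberately left as a pointer, since this diff does not show it.

```python
import os

# Before this PR (vllm-ascend specific; removed below):
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"

# After this PR: dual-batch overlap is provided by upstream vLLM
# (https://github.com/vllm-project/vllm/pull/23693). Consult that PR and
# your vLLM version for the supported enablement flag; it is not part of
# this diff.
```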

### Does this PR introduce _any_ user-facing change?

Yes. The `VLLM_ASCEND_ENABLE_DBO` environment variable is removed, along with
the Ascend DBO example script and the `vllm_ascend.multistream` package.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993

Signed-off-by: zzzzwwjj <1183291235@qq.com>
Author: zzzzwwjj
Date: 2025-10-25 15:53:01 +08:00
Committed by: GitHub
Parent: d9cdc65854
Commit: e5676fc36e
26 changed files with 69 additions and 1588 deletions

View File

@@ -1,52 +0,0 @@
import os
import time

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# enable dual-batch overlap for vllm ascend
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"

# Sample prompts.
prompts = ["The president of the United States is"] * 41
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)


def main():
    # Create an LLM.
    llm = LLM(model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic",
              enforce_eager=True,
              tensor_parallel_size=2,
              max_model_len=4096,
              trust_remote_code=True,
              enable_expert_parallel=True,
              additional_config={
                  "torchair_graph_config": {
                      "enabled": False
                  },
                  "ascend_scheduler_config": {
                      "enabled": True
                  },
              })

    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Add a buffer to wait for profiler in the background process
    # (in case MP is on) to finish writing profiling output.
    time.sleep(10)


if __name__ == "__main__":
    main()

View File

@@ -623,11 +623,8 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
@patch('vllm_ascend.attention.mla_v1.get_forward_context')
@patch("torch.npu.stream")
@patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_decode(self, mock_npu_fused_infer_attention_score,
mock_get_multistream_comm_context, mock_npu_stream,
mock_get_forward_context):
B = 2
N = self.impl.num_kv_heads
@@ -651,8 +648,6 @@ class TestAscendMLAImpl(TestBase):
mock_npu_fused_infer_attention_score.return_value = [
torch.randn(B, N, self.impl.kv_lora_rank), None
]
mock_get_multistream_comm_context.return_value = None
mock_get_forward_context.return_value = MagicMock(capturing=False)
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
attn_metadata)
@@ -660,18 +655,3 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], HD)
self.impl.enable_kv_nz = False
attn_metadata.attn_state = None
mock_return_value = MagicMock()
mock_get_multistream_comm_context.return_value = mock_return_value
mock_return_value.before_comm_event = MagicMock()
mock_return_value.comm_stream = MagicMock()
mock_npu_stream.return_value = MagicMock()
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
attn_metadata)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], HD)

View File

@@ -1,32 +0,0 @@
from tests.ut.base import TestBase
from vllm_ascend.multistream.base import (MSAttentionMetadataSplitConfig,
MSEventKey)
class Testbase(TestBase):
def test_ms_event_key(self):
self.assertEqual(MSEventKey.ATTN_COM_FINISH.value, 0)
self.assertEqual(MSEventKey.ATTN_AR_FINISH.value, 1)
self.assertEqual(MSEventKey.FFN_COM_FINISH.value, 2)
self.assertEqual(MSEventKey.FFN_AR_FINISH.value, 3)
self.assertEqual(MSEventKey.MOE_BEFORE_COMM.value, 4)
self.assertEqual(MSEventKey.MOE_AFTER_COMM.value, 5)
self.assertEqual(MSEventKey.MOE_SE_COMM_FINISH.value, 6)
self.assertEqual(MSEventKey.MOE_SE_COMP_FINISH.value, 7)
self.assertEqual(MSEventKey.MOE_GATE_FINISH.value, 8)
def test_ms_attention_metadata_split_config_default(self):
config = MSAttentionMetadataSplitConfig()
self.assertEqual(config.num_micro_batches, 2)
self.assertEqual(config.min_total_tokens_to_split, 256)
self.assertEqual(config.min_prefill_tokens_to_split, 64)
def test_ms_attention_metadata_split_config_custom(self):
config = MSAttentionMetadataSplitConfig(
num_micro_batches=4,
min_total_tokens_to_split=512,
min_prefill_tokens_to_split=128)
self.assertEqual(config.num_micro_batches, 4)
self.assertEqual(config.min_total_tokens_to_split, 512)
self.assertEqual(config.min_prefill_tokens_to_split, 128)

View File

@@ -1,47 +0,0 @@
import pytest
from pytest_mock import MockFixture
from tests.ut.base import PytestBase
from vllm_ascend.multistream.decorator import set_multistream_support
class Context:
def __init__(self, attn_metadata=None):
self.attn_metadata = attn_metadata
class TestDecorator(PytestBase):
@pytest.mark.parametrize(
'layer_context, microbatch_context, expected_metadata', [
((-1, None, None), -1, {
"original": True
}),
((-1, None, None), 0, {
"original": True
}),
((0, None, None), -1, {
"original": True
}),
((0, None, [{
"new": True
}]), 0, {
"new": True
}),
])
def test_decorator(self, mocker: MockFixture, layer_context,
microbatch_context, expected_metadata):
def context_func():
return Context(attn_metadata={"original": True})
mocker.patch(
'vllm_ascend.multistream.decorator.get_multistream_layer_context',
return_value=layer_context)
mocker.patch(
'vllm_ascend.multistream.decorator.get_multistream_microbatch_context',
return_value=microbatch_context)
context = set_multistream_support()(context_func)()
assert context.attn_metadata == expected_metadata

View File

@@ -1,198 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import MagicMock, patch
import pytest
import torch
from tests.ut.base import PytestBase
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
MultiStreamPreTransformerLayer)
from vllm_ascend.multistream.metadata import MultiStreamMetadata
# === fixture: mock tensor input ===
@pytest.fixture
def input_tensors():
return [torch.randn(2, 128), torch.randn(2, 128)]
# === mock get_forward_context ===
class DummyContext:
def __init__(self, attn_metadata):
self.attn_metadata = attn_metadata
class TestMultiStreamPreTransformerLayer(PytestBase):
# === test when multistream_metadata is None ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_no_multistream_metadata(self, mock_set_ctx, mock_get_ctx,
input_tensors):
mock_get_ctx.return_value = DummyContext(attn_metadata="dummy_meta")
layer = MultiStreamPreTransformerLayer(multistream_metadata=None)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out == "dummy_meta"
assert input_out == input_tensors
mock_set_ctx.assert_called_once_with(-1, None, None)
# === test when attn_metadata is None ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_no_attn_metadata(self, mock_set_ctx, mock_get_ctx,
input_tensors):
mock_get_ctx.return_value = DummyContext(attn_metadata=None)
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
layer = MultiStreamPreTransformerLayer(
multistream_metadata=dummy_metadata)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out is None
assert input_out == input_tensors
mock_set_ctx.assert_called_once_with(-1, None, None)
# === test when do_ms=False (no split needed) ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_no_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
dummy_attn = "original_attn"
mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
dummy_metadata.split_micro_batch.return_value = (False, "same_attn",
input_tensors, None)
layer = MultiStreamPreTransformerLayer(
multistream_metadata=dummy_metadata)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out == "same_attn"
assert input_out == input_tensors
mock_set_ctx.assert_called_once_with(-1, None, None)
# === test when do_ms=True (split occurred) ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
dummy_attn = "original_attn"
mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)
split_inputs = [[t[:1], t[1:]] for t in input_tensors]
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
dummy_metadata.start_layer = 2
dummy_metadata.split_micro_batch.return_value = (True,
["attn1", "attn2"],
split_inputs, None)
layer = MultiStreamPreTransformerLayer(
multistream_metadata=dummy_metadata)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out == ["attn1", "attn2"]
assert input_out == split_inputs
mock_set_ctx.assert_called_once_with(2, dummy_metadata,
["attn1", "attn2"])
class TestMultiStreamPostTransformerLayer(PytestBase):
def test_post_forward_metadata_none(self, input_tensors):
layer = MultiStreamPostTransformerLayer(multistream_metadata=None)
output = layer.forward(input_tensors)
assert output == input_tensors
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
dummy_metadata.ms_config = None
layer = MultiStreamPostTransformerLayer(
multistream_metadata=dummy_metadata)
output = layer.forward(input_tensors)
assert output == input_tensors
@patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
@patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
def test_post_forward_normal_flow(self, mock_reset_ctx, mock_get_ctx,
input_tensors):
A_instance_of_MultiStreamMetadata = MultiStreamMetadata(
calculate_stream=MagicMock(),
communicate_stream=MagicMock(),
start_layer=0,
end_layer=1,
event_keys=[],
multistream_config=None,
)
dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata)
dummy_metadata.ms_config.num_micro_batches = 4
dummy_metadata.end_layer = 10
mock_get_ctx.return_value = (
5, # layer_index
dummy_metadata, # ms_metadata
"dummy_attn_metadata" # ms_attn_metadata
)
dummy_metadata.merge_micro_batches.return_value = "merged_result"
layer = MultiStreamPostTransformerLayer(
multistream_metadata=dummy_metadata)
output = layer.forward(input_tensors)
# check wait_event
dummy_metadata.try_wait_event.assert_called_once_with(
9, # end_layer - 1
3, # num_micro_batches - 1
MSEventKey.FFN_AR_FINISH)
mock_reset_ctx.assert_called_once()
assert output == "merged_result"
@patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
@patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
def test_post_forward_with_custom_wait_layer(self, mock_reset_ctx,
mock_get_ctx, input_tensors):
A_instance_of_MultiStreamMetadata = MultiStreamMetadata(
calculate_stream=MagicMock(),
communicate_stream=MagicMock(),
start_layer=0,
end_layer=1,
event_keys=[],
multistream_config=None,
)
dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata)
dummy_metadata.ms_config.num_micro_batches = 4
dummy_metadata.end_layer = 10
mock_get_ctx.return_value = (
3, # layer_index
dummy_metadata,
"dummy_attn_metadata")
dummy_metadata.merge_micro_batches.return_value = "merged_result"
layer = MultiStreamPostTransformerLayer(
multistream_metadata=dummy_metadata)
output = layer.forward(input_tensors, wait_layer_index=7)
dummy_metadata.try_wait_event.assert_called_once_with(
7, 3, MSEventKey.FFN_AR_FINISH)
mock_reset_ctx.assert_called_once()
assert output == "merged_result"

View File

@@ -1,246 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
MultiStreamMetadata,
MultiStreamStepMetadata,
split_micro_batches_tensors)
class TestMetaData(TestBase):
def setUp(self):
self.test_tensors_list = [torch.randn(100, 1024) for i in range(3)]
self.test_tensors = torch.randn(100, 1024)
self.test_tensors_dict = {
'query': torch.randn(100, 1024),
'key': torch.randn(100, 1024),
'value': torch.randn(100, 1024)
}
self.split_index = 50
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
self.metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
def test_split_micro_batches_tensors(self):
test_tensors_list_res = split_micro_batches_tensors(
self.test_tensors_list, self.split_index)
test_tensors_res = split_micro_batches_tensors(self.test_tensors,
self.split_index)
keys = ['query', 'key', 'value']
test_tensors_dict_res = split_micro_batches_tensors(
self.test_tensors_dict, self.split_index, keys)
for i in range(3):
self.assertEqual(len(test_tensors_list_res[i][0]),
self.split_index)
self.assertEqual(
len(test_tensors_list_res[i][0]) +
len(test_tensors_list_res[i][1]), 100)
self.assertEqual(len(test_tensors_res[0]), self.split_index)
self.assertEqual(
len(test_tensors_res[0]) + len(test_tensors_res[1]), 100)
for key in keys:
self.assertEqual(len(test_tensors_dict_res[0][key]),
self.split_index)
self.assertEqual(
len(test_tensors_dict_res[0][key]) +
len(test_tensors_dict_res[1][key]), 100)
def test_default_init_multistream_step_metadata(self):
metadata = MultiStreamStepMetadata()
self.assertIsNone(metadata.comm_stream)
self.assertIsNone(metadata.before_comm_event)
self.assertIsNone(metadata.after_comm_event)
def test_custom_init_multistream_step_metadata(self):
mockStream = MagicMock(spec=torch.npu.Stream)
mockEvent1 = MagicMock(spec=torch.npu.Event)
mockEvent2 = MagicMock(spec=torch.npu.Event)
metadata = MultiStreamStepMetadata(mockStream, mockEvent1, mockEvent2)
self.assertEqual(metadata.comm_stream, mockStream)
self.assertEqual(metadata.before_comm_event, mockEvent1)
self.assertEqual(metadata.after_comm_event, mockEvent2)
def test_default_init_multistream_config(self):
config = MultiStreamConfig()
self.assertEqual(config.min_total_tokens_to_split, 256)
self.assertEqual(config.min_prefill_tokens_to_split, 64)
self.assertEqual(config.num_micro_batches, 2)
self.assertEqual(config.imbalance_ratio, 0.1)
def test_custom_init_multistream_config(self):
config = MultiStreamConfig(512, 128, 1, 0.2)
self.assertEqual(config.min_total_tokens_to_split, 512)
self.assertEqual(config.min_prefill_tokens_to_split, 128)
self.assertEqual(config.num_micro_batches, 1)
self.assertEqual(config.imbalance_ratio, 0.2)
def test_init_multistream_metadata(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock()]
multistream_config = MagicMock(spec=MultiStreamConfig)
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
self.assertEqual(metadata.calculate_stream, mock_stream)
self.assertEqual(metadata.communicate_stream, mock_stream)
self.assertEqual(metadata.start_layer, 1)
self.assertEqual(metadata.end_layer, 3)
self.assertEqual(metadata.ms_config, multistream_config)
self.assertTrue(metadata.causal_lm)
def test_build_events(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
with patch('torch.npu.Event', return_value=mock_event):
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MultiStreamConfig(
num_micro_batches=2,
min_total_tokens_to_split=256,
min_prefill_tokens_to_split=64)
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
expected_events = {
0: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
},
1: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
},
2: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
}
}
self.assertEqual(metadata.ms_events, expected_events)
def test_build_ms_split_config(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
multistream_config.num_micro_batches = 2
multistream_config.min_total_tokens_to_split = 256
multistream_config.min_prefill_tokens_to_split = 64
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
self.assertIsNotNone(metadata.ms_split_config)
self.assertEqual(metadata.ms_split_config.num_micro_batches,
multistream_config.num_micro_batches)
self.assertEqual(metadata.ms_split_config.min_total_tokens_to_split,
multistream_config.min_total_tokens_to_split)
self.assertEqual(metadata.ms_split_config.min_prefill_tokens_to_split,
multistream_config.min_prefill_tokens_to_split)
def test_try_wait_event(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
with patch('torch.npu.Event', return_value=mock_event):
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
metadata.try_wait_event(layer_index=1,
micro_batch_index=0,
event_key=event_keys[0])
mock_event.wait.assert_called_once()
def test_try_record_event(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
with patch('torch.npu.Event', return_value=mock_event):
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
metadata.try_record_event(layer_index=1,
micro_batch_index=0,
event_key=event_keys[0])
mock_event.record.assert_called_once()
def test_merge_batches_none_input(self):
input_tensors = None
result = self.metadata.merge_micro_batches(input_tensors)
self.assertIsNone(result)
def test_merge_batches_single_tensor_input(self):
input_tensors = [torch.tensor([1, 2, 3])]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 1)
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3])))
def test_merge_batches_list_of_tensors_input(self):
input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 2)
self.assertEqual(result, input_tensors)
def test_merge_batches_nested_list_input(self):
input_tensors = [[torch.tensor([1, 2]),
torch.tensor([3, 4])],
[torch.tensor([5, 6]),
torch.tensor([7, 8])]]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3, 4])))
self.assertTrue(torch.equal(result[1], torch.tensor([5, 6, 7, 8])))

View File

@@ -1,147 +0,0 @@
from unittest.mock import MagicMock
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import (compute_split_seq_index,
model_input_split_v1_mla_attn,
split_attn_int_type,
split_attn_tensor_type)
class TestMsSplit(TestBase):
def test_decode_only(self):
result = compute_split_seq_index(
query_lens=None,
attn_state=AscendAttentionState.DecodeOnly,
num_tokens=10)
self.assertEqual(result, [5, 5])
def test_perfect_balance(self):
query_lens = [2, 3, 5]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [5, 2])
def test_imbalance(self):
query_lens = [1, 2, 3, 4]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_query_lens_none(self):
with self.assertRaises(AssertionError):
compute_split_seq_index(
query_lens=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
def test_empty_query_lens(self):
query_lens: list[int] = []
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_single_query_len(self):
query_lens = [10]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_split_attn_tensor_type_middle(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 3
expected_result = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_start(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 0
expected_result = [torch.tensor([]), torch.tensor([1, 2, 3, 4, 5])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_end(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 5
expected_result = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_empty_tensor(self):
input_tensor = torch.tensor([])
index = 0
expected_result = [torch.tensor([]), torch.tensor([])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_int_type_index_greater_than_var(self):
var = 5
index = 10
expected_result = [5, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_equal_to_var(self):
var = 5
index = 5
expected_result = [5, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_less_than_var(self):
var = 10
index = 5
expected_result = [5, 5]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_zero(self):
var = 10
index = 0
expected_result = [0, 10]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_var_zero(self):
var = 0
index = 5
expected_result = [0, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_both_zero(self):
var = 0
index = 0
expected_result = [0, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_v1_mla_attn_input_none(self):
attn_metadata = None
ascendMLAPrefillMetadata = MagicMock()
ms_split_config = MSAttentionMetadataSplitConfig(num_micro_batches=1)
result = model_input_split_v1_mla_attn(attn_metadata,
ascendMLAPrefillMetadata,
ms_split_config)
self.assertEqual(result, [None])

View File

@@ -210,9 +210,6 @@ class AscendMetadata:
# (num_tokens,)
slot_mapping: torch.Tensor = None
# *************************** Other Properties *************************** #
enable_dbo_across_dp: bool = False
prefill: Optional[AscendMetadataForPrefill] = None
decode_meta: Optional[AscendMetadataForDecode] = None
@@ -371,7 +368,6 @@ class AscendAttentionMetadataBuilder:
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
num_prefills=num_prefills,
num_decodes=num_decodes,
prefill=prefill_metadata,

View File

@@ -36,9 +36,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import get_graph_params
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
@@ -184,7 +181,6 @@ class AscendMLAMetadata:
decode: Optional[AscendMLADecodeMetadata] = None
prefill: Optional[AscendMLAPrefillMetadata] = None
enable_dbo_across_dp: bool = False
def __post_init__(self):
pass
@@ -195,17 +191,6 @@ class AscendMLAMetadata:
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendMLAMetadata"]:
"""Split metadata for multi-stream with AscendMLAMetadata"""
return model_input_split_v1_mla_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendMLAMetadata,
)
M = TypeVar("M", bound=AscendMLAMetadata)
@@ -538,7 +523,6 @@ class AscendMLAMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
)
def build_for_graph_capture(
@@ -1158,13 +1142,7 @@ class AscendMLAImpl(MLAAttentionImpl):
else:
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
q_nope, k_nope, k_nope, **common_kwargs)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj(attn_output)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
@@ -1423,12 +1401,7 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
kv_cache[0].shape[1], attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
o_proj_input[:num_decode_tokens] = output_decode
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[:num_decode_tokens] = output_decode
if prefill_preprocess_res is not None:
@@ -1445,18 +1418,10 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
o_proj_input[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[
num_decode_tokens:num_actual_tokens] = output_prefill
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
if current_ms_metadata is None:
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
@@ -1465,16 +1430,7 @@ class AscendMLAImpl(MLAAttentionImpl):
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input
has_prefill = attn_metadata.num_prefills > 0
@@ -1719,18 +1675,9 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
seq_mask_pcp[:, i])
attn_output = attn_out_g
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj(attn_output)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
# TODO use update op to replace this
def _update_out_and_lse(
self,
out: torch.Tensor,

View File

@@ -17,11 +17,8 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@@ -138,7 +135,6 @@ class AscendSFAMetadata:
decode: Optional[AscendSFADecodeMetadata] = None
prefill: Optional[AscendSFAPrefillMetadata] = None
enable_dbo_across_dp: bool = False
def __post_init__(self):
pass
@@ -149,17 +145,6 @@ class AscendSFAMetadata:
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendSFAMetadata"]:
"""Split metadata for multi-stream with AscendSFAMetadata"""
return model_input_split_v1_mla_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendMLAMetadata,
)
M = TypeVar("M", bound=AscendSFAMetadata)
@@ -434,7 +419,6 @@ class AscendSFAMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
)

View File

@@ -91,8 +91,6 @@ class AscendCommonAttentionMetadata:
attn_state: Any = None
enable_dbo_across_dp: bool = False
is_only_prefill: bool = False
graph_pad_size: int = -1

View File

@@ -82,9 +82,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
),
# Whether to enable DBO feature for deepseek model.
"VLLM_ASCEND_ENABLE_DBO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
# Whether to enable the model execute time observe profile. Disable it when
# running vllm ascend in production environment.
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":

View File

@@ -1,29 +0,0 @@
from dataclasses import dataclass
from enum import Enum


class MSEventKey(Enum):
    ATTN_COM_FINISH = 0
    ATTN_AR_FINISH = 1
    FFN_COM_FINISH = 2
    FFN_AR_FINISH = 3
    # events for MOE dispatch and combine
    MOE_BEFORE_COMM = 4
    MOE_AFTER_COMM = 5
    # events for shared expert
    MOE_SE_COMM_FINISH = 6
    MOE_SE_COMP_FINISH = 7
    MOE_GATE_FINISH = 8


@dataclass
class MSAttentionMetadataSplitConfig:
    """
    micro batch split config for split attention metadata
    """
    # micro batch num
    num_micro_batches: int = 2
    # split micro batches only when total tokens >= min_total_tokens_to_split
    min_total_tokens_to_split: int = 256
    # split micro batches only when prefill tokens >= min_prefill_tokens_to_split
    min_prefill_tokens_to_split: int = 64
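
The two thresholds exist because splitting a tiny batch into micro-batches costs more in synchronization than it saves in overlap. A minimal sketch of their intended semantics, using a hypothetical `should_split` helper (the real consumers of these fields were the removed model integrations):

```python
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig

# Hypothetical helper: only split when the batch is large enough overall
# and carries enough prefill tokens to amortize the extra events/streams.
def should_split(cfg: MSAttentionMetadataSplitConfig,
                 total_tokens: int, prefill_tokens: int) -> bool:
    return (total_tokens >= cfg.min_total_tokens_to_split
            and prefill_tokens >= cfg.min_prefill_tokens_to_split)

cfg = MSAttentionMetadataSplitConfig()  # 2 micro-batches, 256/64 defaults
print(should_split(cfg, total_tokens=512, prefill_tokens=128))  # True
print(should_split(cfg, total_tokens=100, prefill_tokens=80))   # False
```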

View File

@@ -1,67 +0,0 @@
from contextlib import contextmanager
from typing import Any

_ms_comm_context: Any = None
_cur_micro_batch_num: int = -1
_ms_layer_index_context: int = -1
_ms_metadata_context: Any = None
_ms_attn_metadata_context: Any = None


def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
                                  attn_metadata: Any):
    """
    set multistream layer context before transformer layers
    """
    global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
    _ms_layer_index_context = start_layer
    _ms_metadata_context = ms_metadata
    _ms_attn_metadata_context = attn_metadata


def reset_multistream_layer_context():
    """
    reset multistream layer context
    """
    global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
    _ms_layer_index_context = -1
    _ms_metadata_context = None
    _ms_attn_metadata_context = None


def get_multistream_layer_context():
    """
    get multistream layer context
    """
    return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context


def advance_step_multistream_layer_context():
    """
    advance multistream layer index context
    """
    global _ms_layer_index_context
    _ms_layer_index_context += 1


def get_multistream_comm_context() -> Any:
    """Get the current comm forward context."""
    return _ms_comm_context


def get_multistream_microbatch_context() -> int:
    return _cur_micro_batch_num


@contextmanager
def set_multistream_context(context: Any, micro_batch_num: int):
    """A context manager that stores the current comm forward context,
    can be attention metadata, etc."""
    global _ms_comm_context, _cur_micro_batch_num
    _ms_comm_context = context
    _cur_micro_batch_num = micro_batch_num
    try:
        yield
    finally:
        _ms_comm_context = None
        _cur_micro_batch_num = -1
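
Taken together, the module is a set of module-level globals plus one context manager. A small usage sketch against the removed API (runnable only on a pre-PR checkout):

```python
from vllm_ascend.multistream.context import (
    get_multistream_comm_context, get_multistream_microbatch_context,
    set_multistream_context)

step_metadata = object()  # stand-in for a MultiStreamStepMetadata

with set_multistream_context(step_metadata, micro_batch_num=0):
    # Inside the block, attention kernels can discover the comm stream /
    # events and which micro-batch is currently executing.
    assert get_multistream_comm_context() is step_metadata
    assert get_multistream_microbatch_context() == 0

# On exit the globals reset, so lookups fall back to "no overlap active".
assert get_multistream_comm_context() is None
assert get_multistream_microbatch_context() == -1
```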

View File

@@ -1,22 +0,0 @@
from .context import (get_multistream_layer_context,
                      get_multistream_microbatch_context)


# vllm v1 use get_forward_context to get the attn_metadata,
# we can use this decorator to update the attn metadata
def set_multistream_support():

    def decorator(func):

        def wrapper():
            context = func()
            layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
            )
            micro_batch_num = get_multistream_microbatch_context()
            if layer_index != -1 and micro_batch_num != -1:
                context.attn_metadata = attn_metadata[micro_batch_num]
            return context

        return wrapper

    return decorator
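
The decorator rewrites whatever context object the wrapped getter returns, swapping in the per-micro-batch attention metadata whenever a multistream layer context is active. A self-contained sketch against the removed API; the `Ctx` class merely stands in for vLLM's forward context:

```python
from vllm_ascend.multistream.context import (set_multistream_context,
                                             set_multistream_layer_context)
from vllm_ascend.multistream.decorator import set_multistream_support

class Ctx:
    def __init__(self, attn_metadata):
        self.attn_metadata = attn_metadata

@set_multistream_support()
def get_ctx():
    # Plays the role of vllm.forward_context.get_forward_context().
    return Ctx(attn_metadata="merged")

# Publish per-micro-batch metadata, then fetch the context while
# micro-batch 1 is "executing": the decorator swaps in its slice.
set_multistream_layer_context(0, None, ["meta_mb0", "meta_mb1"])
with set_multistream_context(None, micro_batch_num=1):
    assert get_ctx().attn_metadata == "meta_mb1"
```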

View File

@@ -1,61 +0,0 @@
from typing import List, Optional, Tuple, Union

import torch
from vllm.forward_context import get_forward_context

from .base import MSEventKey
from .context import (get_multistream_layer_context,
                      reset_multistream_layer_context,
                      set_multistream_layer_context)
from .metadata import MultiStreamMetadata


class MultiStreamPreTransformerLayer(torch.nn.Module):

    def __init__(self, multistream_metadata: MultiStreamMetadata):
        super().__init__()
        self.multistream_metadata = multistream_metadata

    def forward(
        self,
        intput_tensors: List[torch.Tensor],
    ):
        attn_metadata = get_forward_context().attn_metadata
        if self.multistream_metadata is None or attn_metadata is None:
            set_multistream_layer_context(-1, None, None)
            return attn_metadata, intput_tensors
        # TODO add attn_metadata management
        do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(
            attn_metadata, intput_tensors)
        if do_ms:
            set_multistream_layer_context(
                self.multistream_metadata.start_layer,
                self.multistream_metadata, attn_metadata)
        else:
            set_multistream_layer_context(-1, None, None)
        return attn_metadata, intput_tensors


class MultiStreamPostTransformerLayer(torch.nn.Module):

    def __init__(self, multistream_metadata: MultiStreamMetadata):
        super().__init__()
        self.multistream_metadata = multistream_metadata

    def forward(self,
                input_tensors: Union[List[Tuple[torch.Tensor]],
                                     List[torch.Tensor],
                                     List[List[torch.Tensor]]],
                wait_layer_index: Optional[int] = None):
        if self.multistream_metadata is None or self.multistream_metadata.ms_config is None:
            return input_tensors
        layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context(
        )
        if layer_index >= 0:
            true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index
            self.multistream_metadata.try_wait_event(
                true_wait_layer,
                self.multistream_metadata.ms_config.num_micro_batches - 1,
                MSEventKey.FFN_AR_FINISH)
        reset_multistream_layer_context()
        return self.multistream_metadata.merge_micro_batches(input_tensors)

View File

@@ -1,182 +0,0 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import torch
from vllm.sequence import IntermediateTensors
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from .base import MSAttentionMetadataSplitConfig, MSEventKey
def split_micro_batches_tensors(input_tensors,
split_index: int,
keys: Optional[List[str]] = None):
if isinstance(input_tensors, list):
micro_batches = []
for tensor in input_tensors:
if tensor is None:
micro_batches.append([None, None])
else:
micro_batches.append(
[tensor[:split_index], tensor[split_index:]])
return micro_batches
elif isinstance(input_tensors, torch.Tensor):
return [input_tensors[:split_index], input_tensors[split_index:]]
elif input_tensors is None:
return [None, None]
elif isinstance(input_tensors, Dict):
assert keys is not None
micro_batches_pre = {}
for key in keys:
micro_batches_pre[key] = input_tensors[key][:split_index]
micro_batches_post = {}
for key in keys:
micro_batches_post[key] = input_tensors[key][split_index:]
return [micro_batches_pre, micro_batches_post]
else:
raise NotImplementedError
@dataclass
class MultiStreamStepMetadata:
comm_stream: torch.npu.Stream = None
before_comm_event: torch.npu.Event = None
after_comm_event: torch.npu.Event = None
@dataclass
class MultiStreamConfig:
"""Controls the behavior of multi-stream models."""
min_total_tokens_to_split: int = 256
min_prefill_tokens_to_split: int = 64
num_micro_batches: int = 2
imbalance_ratio: float = 0.1
class MultiStreamMetadata:
# direct stream
calculate_stream = None
# delay stream
communicate_stream = None
# events
ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {}
# multi-stream-flag
enable_multi_stream: bool = False
def __init__(
self,
calculate_stream: torch.npu.Stream,
communicate_stream: torch.npu.Stream,
start_layer: int,
end_layer: int,
event_keys: List[MSEventKey],
multistream_config: Optional[MultiStreamConfig],
causal_lm: bool = True,
):
self.calculate_stream = calculate_stream
self.communicate_stream = communicate_stream
self.start_layer = start_layer
self.end_layer = end_layer
self.ms_config = multistream_config
self.causal_lm = causal_lm
self._build_events(event_keys)
self._build_ms_split_config()
def _build_events(self, event_keys):
if self.ms_config is not None:
for i in range(self.start_layer - 1, self.end_layer):
self.ms_events[i] = {}
for j in range(self.ms_config.num_micro_batches):
self.ms_events[i][j] = {}
for key in event_keys:
self.ms_events[i][j][key] = torch.npu.Event()
def _build_ms_split_config(self):
if self.ms_config is not None:
self.ms_split_config = MSAttentionMetadataSplitConfig(
num_micro_batches=self.ms_config.num_micro_batches,
min_total_tokens_to_split=self.ms_config.
min_total_tokens_to_split,
min_prefill_tokens_to_split=self.ms_config.
min_prefill_tokens_to_split,
)
def try_wait_event(self, layer_index: int, micro_batch_index: int,
event_key: MSEventKey):
self.ms_events[layer_index][micro_batch_index][event_key].wait()
def try_record_event(self, layer_index: int, micro_batch_index: int,
event_key: MSEventKey):
self.ms_events[layer_index][micro_batch_index][event_key].record()
def split_micro_batch(
self,
attn_metadata: "AscendMLAMetadata",
intput_tensors: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
intermediate_tensors_keys: Optional[List[str]] = None,
) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[
List[torch.Tensor], List[List[torch.Tensor]]], Union[
IntermediateTensors, List[IntermediateTensors]]]:
attn_metadata_list = attn_metadata.split_metadata_for_multistream(
self.ms_split_config)
if len(attn_metadata_list) == 1:
return False, attn_metadata_list[
0], intput_tensors, intermediate_tensors
split_index = attn_metadata_list[0].slot_mapping.shape[0]
input_tensors = split_micro_batches_tensors(intput_tensors,
split_index)
if intermediate_tensors is not None:
inter_tensors_list = split_micro_batches_tensors(
intermediate_tensors.tensors, split_index,
intermediate_tensors_keys)
intermediate_tensors = [
IntermediateTensors(inter_tensors)
for inter_tensors in inter_tensors_list
]
return True, attn_metadata_list, input_tensors, intermediate_tensors
def merge_micro_batches(
self, input_tensors: Union[List[torch.Tensor],
List[List[torch.Tensor]]]
) -> List[torch.Tensor]:
if input_tensors is None or isinstance(input_tensors[0], torch.Tensor):
return input_tensors
batch: List[Optional[torch.Tensor]] = []
for tensors in input_tensors:
if tensors is None or tensors[0] is None:
batch.append(None)
else:
batch.append(torch.cat(tensors, dim=0))
return batch
def make_multistream_metadata_ds(
start_layer: int,
end_layer: int,
causal_lm: bool = True,
multistream_config: Optional[MultiStreamConfig] = None,
):
if multistream_config is None:
return None
event_keylist = [
MSEventKey.ATTN_COM_FINISH,
MSEventKey.ATTN_AR_FINISH,
MSEventKey.FFN_COM_FINISH,
MSEventKey.FFN_AR_FINISH,
MSEventKey.MOE_BEFORE_COMM,
MSEventKey.MOE_AFTER_COMM,
MSEventKey.MOE_SE_COMM_FINISH,
MSEventKey.MOE_SE_COMP_FINISH,
MSEventKey.MOE_GATE_FINISH,
]
return MultiStreamMetadata(
calculate_stream=torch.npu.current_stream(),
communicate_stream=torch.npu.Stream(),
start_layer=start_layer,
end_layer=end_layer,
multistream_config=multistream_config,
event_keys=event_keylist,
causal_lm=causal_lm,
)
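
Of the pieces above, `split_micro_batches_tensors` is plain tensor slicing with no NPU dependency, so its behavior is easy to pin down in isolation. A sketch (importing it requires a pre-PR checkout; the tensors themselves can live on CPU):

```python
import torch
from vllm_ascend.multistream.metadata import split_micro_batches_tensors

hidden = torch.arange(10).reshape(5, 2)
residual = torch.ones(5, 2)

# A list of tensors is split row-wise into [first, second] micro-batches.
mb = split_micro_batches_tensors([hidden, residual], split_index=3)
assert mb[0][0].shape == (3, 2) and mb[0][1].shape == (2, 2)

# Dicts (e.g. IntermediateTensors.tensors) need the key list spelled out.
mb_dict = split_micro_batches_tensors({"h": hidden}, 3, ["h"])
assert mb_dict[0]["h"].shape == (3, 2) and mb_dict[1]["h"].shape == (2, 2)
```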

View File

@@ -1,247 +0,0 @@
from copy import deepcopy
from typing import Any, List, Optional
import numpy as np
import torch
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from .base import MSAttentionMetadataSplitConfig
def compute_split_seq_index(
query_lens: Optional[list[int]],
attn_state: AscendAttentionState,
num_tokens: int,
imbalance_ratio: float = 0.1,
) -> list[int]:
if attn_state != AscendAttentionState.DecodeOnly:
assert query_lens is not None
total_tokens = sum(query_lens)
# the first index in last split
tokens, split_index = 0, 0
for value in query_lens:
tokens += value
split_index += 1
if tokens >= total_tokens // 2:
# check the current split index
if abs(tokens -
total_tokens // 2) < total_tokens * imbalance_ratio:
return [tokens, split_index]
# check the previous split index
elif abs(tokens - total_tokens // 2 -
value) < total_tokens * imbalance_ratio:
return [tokens - value, split_index - 1]
# fail to split if it is imbalanced
# TODO: split tokens in seq
else:
return [0, 0]
else:
tokens = num_tokens // 2
return [tokens, tokens]
return [0, 0]
def split_attn_tensor_type(
input_tensor: torch.Tensor,
index: int,
) -> List[torch.Tensor]:
return [input_tensor[:index], input_tensor[index:]]
def split_attn_int_type(
var: int,
index: int,
) -> List[torch.Tensor]:
return [min(var, index), max(var - index, 0)]
def model_input_split_v1_mla_attn(
attn_metadata,
_metadata_cls,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> List[Any]:
assert 0 < ms_split_config.num_micro_batches < 3
if attn_metadata is None:
return [attn_metadata]
[token_index,
seq_index] = compute_split_seq_index(attn_metadata.query_lens,
attn_metadata.attn_state,
attn_metadata.num_decode_tokens)
if token_index == 0 or seq_index == 0 or seq_index == len(
attn_metadata.query_lens):
return [attn_metadata]
query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
dtype=int)
np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
if attn_metadata.num_prefills > 0:
prefill_query_start_loc = np.zeros(
shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
np.cumsum(attn_metadata.prefill.query_lens,
out=prefill_query_start_loc[1:])
# split attn metadata
[slot_mapping_pre,
slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
token_index)
[num_decodes_pre,
num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
seq_index)
[num_decode_tokens_pre, num_decode_tokens_post
] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
[num_prefills_pre, num_prefills_post
] = split_attn_int_type(attn_metadata.num_prefills,
max(0, seq_index - attn_metadata.num_decodes))
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
query_start_loc_pre = query_start_loc_post = None
if attn_metadata.query_start_loc is not None:
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
query_start_loc_post = deepcopy(
attn_metadata.query_start_loc[seq_index:]
) - attn_metadata.query_start_loc[seq_index]
[block_table_pre,
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
seq_index)
assert attn_metadata.attn_mask is not None
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
# the attn_mla kernel in torch npu only accept 128*128 attn mask
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
attn_state_pre = attn_state_post = attn_metadata.attn_state
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
# should be none in decode only state
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
else:
# chunked prefill
if num_prefills_pre > 0:
attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
seq_lens_pre)].contiguous()
attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_post = attn_metadata.attn_mask[
token_index:, :max(seq_lens_post)].contiguous()
else:
attn_state_pre = AscendAttentionState.DecodeOnly
attn_mask_pre = None
attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_post = attn_metadata.attn_mask[
token_index:, :max(seq_lens_post)].contiguous()
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAPrefillMetadata)
if num_prefills_pre > 0:
# split metadata.prefill
[input_positions_pre, input_positions_post] = split_attn_tensor_type(
attn_metadata.prefill.input_positions,
token_index - attn_metadata.num_decode_tokens)
[block_tables_pre, block_tables_post
] = split_attn_tensor_type(attn_metadata.prefill.block_table,
seq_index - attn_metadata.num_decodes)
[prefill_query_lens_pre, prefill_query_lens_post
] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
seq_index - attn_metadata.num_decodes)
prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[:
seq_index
+
1 -
attn_metadata
.
num_decodes]
prefill_query_start_loc_post = deepcopy(
attn_metadata.prefill.query_start_loc[seq_index -
attn_metadata.num_decodes:]
) - attn_metadata.prefill.query_start_loc[seq_index -
attn_metadata.num_decodes]
context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
context_len_post = seq_lens_post
prefill_max_query_len_pre = max(prefill_query_lens_pre)
prefill_max_query_len_post = max(prefill_query_lens_post)
prefill_pre = AscendMLAPrefillMetadata(
attn_mask=attn_mask_pre,
query_lens=prefill_query_lens_pre,
seq_lens=seq_lens_pre,
query_start_loc=prefill_query_start_loc_pre,
input_positions=input_positions_pre,
context_lens=context_len_pre,
block_table=block_tables_pre,
max_query_len=prefill_max_query_len_pre,
max_seq_lens=context_len_pre.max().item(),
)
prefill_post = AscendMLAPrefillMetadata(
attn_mask=attn_mask_post,
query_lens=prefill_query_lens_post,
seq_lens=seq_lens_post,
query_start_loc=prefill_query_start_loc_post,
input_positions=input_positions_post,
context_lens=context_len_post,
block_table=block_tables_post,
max_query_len=prefill_max_query_len_post,
max_seq_lens=context_len_post.max().item(),
)
decode_pre = attn_metadata.decode
decode_post = None
else:
# prefill is None, split metadata.decode
[input_positions_pre, input_positions_post
] = split_attn_tensor_type(attn_metadata.decode.input_positions,
token_index)
[block_tables_pre, block_tables_post
] = split_attn_tensor_type(attn_metadata.decode.block_table,
seq_index)
[decode_seq_lens_pre,
decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
decode_pre = AscendMLADecodeMetadata(
input_positions=input_positions_pre,
block_table=block_tables_pre,
seq_lens=decode_seq_lens_pre,
max_seq_lens=max(decode_seq_lens_pre),
seq_lens_list=decode_seq_lens_pre.tolist(),
)
decode_post = AscendMLADecodeMetadata(
input_positions=input_positions_post,
block_table=block_tables_post,
seq_lens=decode_seq_lens_post,
max_seq_lens=max(decode_seq_lens_post),
seq_lens_list=decode_seq_lens_post.tolist(),
)
prefill_pre = None
prefill_post = attn_metadata.prefill
# construct metadata
from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata
attention_metadata_pre = _metadata_cls(
num_actual_tokens=token_index,
num_input_tokens=token_index,
head_dim=attn_metadata.head_dim,
slot_mapping=slot_mapping_pre,
seq_lens=seq_lens_pre,
query_start_loc=query_start_loc_pre,
block_tables=block_table_pre,
num_decodes=num_decodes_pre,
num_prefills=num_prefills_pre,
num_decode_tokens=num_decode_tokens_pre,
attn_state=attn_state_pre,
attn_mask=attn_mask_pre,
prefill=prefill_pre,
decode=decode_pre,
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
)
attention_metadata_post = _metadata_cls(
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
num_input_tokens=attn_metadata.num_input_tokens - token_index,
head_dim=attn_metadata.head_dim,
slot_mapping=slot_mapping_post,
seq_lens=seq_lens_post,
query_start_loc=query_start_loc_post,
block_tables=block_table_post,
num_decodes=num_decodes_post,
num_prefills=num_prefills_post,
num_decode_tokens=num_decode_tokens_post,
attn_mask=attn_mask_post,
attn_state=attn_state_post,
prefill=prefill_post,
decode=decode_post,
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
)
return [attention_metadata_pre, attention_metadata_post]
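
`compute_split_seq_index` is the heart of the module: it walks the cumulative query lengths looking for a sequence boundary within `imbalance_ratio` of the half-way token, and refuses to split otherwise. A worked example mirroring the unit tests above (pre-PR checkout):

```python
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.multistream.ms_split import compute_split_seq_index

# 10 tokens over 3 sequences; the boundary after [2, 3] lands exactly on
# total // 2 == 5, within the default 10% imbalance budget -> split.
print(compute_split_seq_index([2, 3, 5],
                              AscendAttentionState.PrefillNoCache, 10))
# -> [5, 2]  (token index 5, sequence index 2)

# No boundary comes close enough to the midpoint -> decline to split.
print(compute_split_seq_index([1, 2, 3, 4],
                              AscendAttentionState.PrefillNoCache, 10))
# -> [0, 0]

# Decode-only batches are one token per sequence, so just halve them.
print(compute_split_seq_index(None, AscendAttentionState.DecodeOnly, 10))
# -> [5, 5]
```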

View File

@@ -122,10 +122,11 @@ class MtpProposer(Proposer):
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None) -> None:
if not self.torchair_graph_enabled:
# TODO: adapt enable_dbo later
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._sync_metadata_across_dp(num_tokens,
with_prefill, False)
(
num_tokens,
num_tokens_across_dp,
with_prefill,
) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill)
moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
@@ -429,10 +430,9 @@ class MtpProposer(Proposer):
if not self.torchair_graph_enabled:
# torch mode need to update num_tokens_across_dp
# TODO: adapt enable_dbo later
(num_input_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._sync_metadata_across_dp(
num_input_tokens, self.runner.with_prefill, False)
(num_input_tokens, num_tokens_across_dp,
with_prefill) = self.runner._sync_metadata_across_dp(
num_input_tokens, self.runner.with_prefill)
else:
# torchair mode can reuse self.runner.num_tokens_across_dp
num_tokens_across_dp = self.runner.num_tokens_across_dp

View File

@@ -264,8 +264,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
max_query_len=common_attn_metadata.max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
attn_state=attn_state)
return attn_metadata

View File

@@ -20,9 +20,6 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
npu_stream_switch, npu_wait_tensor)
@@ -141,7 +138,6 @@ class AscendMLATorchairMetadata:
decode: Optional[AscendMLATorchairDecodeMetadata] = None
prefill: Optional[AscendMLATorchairPrefillMetadata] = None
enable_dbo_across_dp: bool = False
def __post_init__(self):
pass
@@ -152,17 +148,6 @@ class AscendMLATorchairMetadata:
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendMLATorchairMetadata"]:
"""Split metadata for multi-stream with AscendMLATorchairMetadata"""
return model_input_split_v1_mla_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendMLATorchairMetadata,
)
M = TypeVar("M", bound=AscendMLATorchairMetadata)
@@ -576,7 +561,6 @@ class AscendMLATorchairMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
)
def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
@@ -1072,15 +1056,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
context_lens=attn_metadata.decode.seq_lens, # type:ignore
mla_vheadsize=self.kv_lora_rank,
out=attn_output)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj_and_o_proj(attn_output,
enable_multistream_mla)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj_and_o_proj(attn_output)
return self._v_up_proj_and_o_proj(attn_output, enable_multistream_mla)
def forward(
self,
@@ -1248,13 +1225,6 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
prefill_k_c_normed,
prefill_k_pe, kv_cache,
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
o_proj_input[num_decode_tokens:] = output_prefill
else:
o_proj_input[num_decode_tokens:] = output_prefill
if has_decode:
@@ -1269,16 +1239,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
decode_k_nope,
decode_k_pe, kv_cache,
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
o_proj_input[:num_decode_tokens] = output_decode
else:
o_proj_input[:num_decode_tokens] = output_decode
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
if current_ms_metadata is None:
maybe_npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
@@ -1288,16 +1252,6 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input
return output_padded

View File

@@ -110,30 +110,28 @@ class NPUTorchairModelRunner(NPUModelRunner):
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
def _sync_metadata_across_dp(
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
self, num_tokens: int,
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
"""Override from NPUModelRunner to pad num_tokens"""
if self.enable_shared_expert_dp:
# Padding is not required for shared_expert_dp cases in eager mode.
return num_tokens, None, with_prefill, enable_dbo
return num_tokens, None, with_prefill
if self.dp_size == 1:
if not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
num_tokens)
-                return maybe_padded_num_tokens, None, with_prefill, enable_dbo
-            return num_tokens, None, with_prefill, enable_dbo
+                return maybe_padded_num_tokens, None, with_prefill
+            return num_tokens, None, with_prefill
-        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
+        num_tokens_across_dp = torch.zeros(self.dp_size + 1,
dtype=torch.int32,
device="npu")
num_tokens_across_dp[self.dp_rank] = num_tokens
-        num_tokens_across_dp[-2] = int(with_prefill)
-        num_tokens_across_dp[-1] = int(not enable_dbo)
+        num_tokens_across_dp[-1] = int(with_prefill)
dist.all_reduce(num_tokens_across_dp,
group=get_dp_group().device_group)
-        with_prefill = bool(num_tokens_across_dp[-2])
-        enable_dbo = not bool(num_tokens_across_dp[-1])
-        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        with_prefill = bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-1]
if not with_prefill:
max_num_token = num_tokens_across_dp.max().item()
@@ -146,7 +144,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
else:
maybe_padded_num_tokens = num_tokens
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill
def _build_dummy_attn_metadata(
self,

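Only the sync shrinks here; the TorchAir padding logic is untouched. The scheme piggybacks boolean flags on the same all_reduce that gathers token counts: each rank writes its count into its own slot of a zeroed int32 tensor and contributes its flag in a tail slot, so one sum-reduce recovers everything. A self-contained sketch of the reduced layout (dp_size, dp_rank and the inputs are made-up values; the effect of dist.all_reduce is described in the comment):

    import torch

    dp_size, dp_rank = 4, 2           # illustrative values
    num_tokens, with_prefill = 37, False

    buf = torch.zeros(dp_size + 1, dtype=torch.int32)
    buf[dp_rank] = num_tokens         # every other rank's slot stays 0
    buf[-1] = int(with_prefill)       # the one remaining flag slot

    # dist.all_reduce(buf) sums the buffers of all ranks; afterwards
    # buf[:dp_size] holds each rank's token count (the zeros sum away)
    # and buf[-1] > 0 iff ANY rank is in prefill.
    with_prefill = bool(buf[-1])
    num_tokens_across_dp = buf[:-1]

Because the reduce is a sum, bool(buf[-1]) acts as an any-rank OR for with_prefill; the removed int(not enable_dbo) slot used the same sum as an all-ranks AND for DBO.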
View File

@@ -21,8 +21,6 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
-from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import is_enable_nz
from vllm_ascend.worker.npu_input_batch import InputBatch
@@ -141,7 +139,6 @@ class AscendSFATorchairMetadata:
decode: Optional[AscendSFATorchairDecodeMetadata] = None
prefill: Optional[AscendSFATorchairPrefillMetadata] = None
-    enable_dbo_across_dp: bool = False
is_prefill: bool = False
is_decode: bool = False
@@ -154,17 +151,6 @@ class AscendSFATorchairMetadata:
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
-    def split_metadata_for_multistream(
-        self,
-        ms_split_config: MSAttentionMetadataSplitConfig,
-    ) -> list["AscendSFATorchairMetadata"]:
-        """Split metadata for multi-stream with AscendSFATorchairMetadata"""
-        return model_input_split_v1_mla_attn(
-            ms_split_config=ms_split_config,
-            attn_metadata=self,
-            _metadata_cls=AscendSFATorchairMetadata,
-        )
M = TypeVar("M", bound=AscendSFATorchairMetadata)
@@ -616,7 +602,6 @@ class AscendSFATorchairMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
is_prefill=is_prefill,
is_decode=is_decode)
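
The SFA metadata drops the same multistream split hook as the MLA metadata above. Conceptually, model_input_split_v1_mla_attn cut one batch's metadata into two micro-batch halves at a request boundary so each half stayed a self-consistent batch. A toy sketch of that idea (field names trimmed to the essentials; this is not the real dataclass):

    from dataclasses import dataclass

    @dataclass
    class ToyAttnMetadata:
        query_start_loc: list[int]  # prefix offsets, e.g. [0, 4, 9, 12]
        seq_lens: list[int]

    def split_at(meta: ToyAttnMetadata, seq_index: int):
        # Cut after seq_index requests; rebase the second half's offsets
        # so both halves look like ordinary batches.
        token_index = meta.query_start_loc[seq_index]
        first = ToyAttnMetadata(meta.query_start_loc[:seq_index + 1],
                                meta.seq_lens[:seq_index])
        second = ToyAttnMetadata(
            [loc - token_index for loc in meta.query_start_loc[seq_index:]],
            meta.seq_lens[seq_index:])
        return [first, second]

With the split hook gone, both metadata classes are plain per-step containers again.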

View File

@@ -758,13 +758,13 @@ def get_default_buffer_config() -> dict:
def calculate_dp_buffer_size() -> int:
"""
formula of dp buffer size:
-    dp_size + 2 (flags: with_prefill and enable_dbo)
+    dp_size + 1 (flags: with_prefill)
"""
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
dp_size = vllm_config.parallel_config.data_parallel_size
int32_size = torch.iinfo(torch.int32).bits // 8
-    dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
+    dp_buffer_size = math.ceil((dp_size + 1) * int32_size / (1024 * 1024))
return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
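
With one flag gone, the docstring and the computation change in lockstep: one int32 slot per DP rank plus a single flag slot. The result is tiny and in practice clamped by the floor; a quick check of the arithmetic (_MIN_DP_BUFFER_SIZE is defined elsewhere in this file, its value of 1 MB assumed here):

    import math

    dp_size = 16                    # illustrative
    int32_size = 4                  # torch.iinfo(torch.int32).bits // 8
    raw = math.ceil((dp_size + 1) * int32_size / (1024 * 1024))
    # (16 + 1) * 4 = 68 bytes, so the ceil over 1 MiB is already 1
    _MIN_DP_BUFFER_SIZE = 1         # assumed floor, in MB
    print(max(raw, _MIN_DP_BUFFER_SIZE))  # 1

Dropping a slot therefore never changes the allocated size for realistic dp_size; the edit only keeps the formula honest.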

View File

@@ -121,7 +121,6 @@ from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
-from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
@@ -859,8 +858,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
def _sync_metadata_across_dp(
-        self, num_tokens: int, with_prefill: bool, enable_dbo: bool
-    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+        self, num_tokens: int,
+        with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
# TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in
# our case, we still need to sync the other two flags as well. So we need to
# include them in the all_reduce operation, and more over, we CANNOT skip it
@@ -868,17 +867,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here
# immediately once the other two flags are no longer needed.
if self.dp_size == 1:
-            return num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill
-        # Sync num_tokens, with_prefill, enable_dbo across dp ranks
+        # Sync num_tokens, with_prefill across dp ranks
num_tokens_tensor = torch.tensor([
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
],
dtype=torch.int32,
device="npu")
-        flags_tensor = torch.tensor(
-            [int(with_prefill), int(not enable_dbo)],
+        flags_tensor = torch.tensor([int(with_prefill)],
dtype=torch.int32,
device="npu")
@@ -887,12 +885,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
# Unpack the results
-        num_tokens_across_dp = packed_tensor[:-2]
-        synced_flags = packed_tensor[-2:]
+        num_tokens_across_dp = packed_tensor[:-1]
+        synced_flags = packed_tensor[-1:]
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
global_with_prefill = bool(synced_flags[0])
-        global_enable_dbo = not bool(synced_flags[1])
# Create a tensor for num_tokens_after_padding
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
@@ -900,28 +897,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
device="cpu",
dtype=torch.int32)
-        return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
-    def _check_dbo_is_valid(self, query_lens: torch.Tensor,
-                            attn_state: AscendAttentionState,
-                            num_tokens: int) -> bool:
-        # do the checks for dp + dbo
-        if attn_state in [
-                AscendAttentionState.DecodeOnly,
-                AscendAttentionState.SpecDecoding
-        ]:
-            return False
-        # considering the case that one dp rank may enable dbo while others may not
-        if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO:
-            return False
-        # TODO: remove it if token-level microbatch is enabled
-        [token_index,
-         seq_index] = compute_split_seq_index(query_lens, attn_state,
-                                              num_tokens)
-        if token_index == 0 or seq_index == 0 or seq_index == len(
-                query_lens) or num_tokens < 256:
-            return False
-        return True
+        return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill
def get_model(self) -> nn.Module:
# get raw model out of the aclgraph wrapper.
@@ -1430,16 +1406,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
]
self.query_lens = torch.from_numpy(num_scheduled_tokens)
-        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
-                                              attn_state,
-                                              total_num_scheduled_tokens)
        # Get info across DP ranks.
        # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
        # Otherwise, it's just max_tokens_across_dp_cpu
-        (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._sync_metadata_across_dp(num_input_tokens,
-                                                     with_prefill, enable_dbo)
+        (maybe_padded_num_tokens, num_tokens_across_dp,
+         with_prefill) = self._sync_metadata_across_dp(num_input_tokens,
+                                                       with_prefill)
# TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
# We should consider removing maybe_padded_num_tokens later
@@ -1707,7 +1680,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
-            enable_dbo_across_dp=enable_dbo,
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
max_query_len=max_num_scheduled_tokens,
graph_pad_size=self.graph_pad_size,
@@ -2603,8 +2575,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens = math.ceil(num_tokens / tp_size) * tp_size
# Padding for DP
-        (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
+        (num_tokens, num_tokens_across_dp,
+         with_prefill) = self._sync_metadata_across_dp(num_tokens,
+                                                       with_prefill)
moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
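
The dummy-run path mirrors the real one: the former enable_dbo=False argument and the discarded fourth return value disappear, leaving a three-element unpack. For reference, the slot layout both call sites now share, as a standalone demo with a plain sum standing in for dist.all_reduce (dp_size = 2 and the token counts are made up):

    import torch

    # Per-rank contributions before the reduce: [tokens..., with_prefill]
    rank0 = torch.tensor([11, 0, 1], dtype=torch.int32)  # 11 tokens, prefill
    rank1 = torch.tensor([0, 23, 0], dtype=torch.int32)  # 23 tokens, decode
    packed_tensor = rank0 + rank1   # what dist.all_reduce would produce

    num_tokens_across_dp = packed_tensor[:-1]      # tensor([11, 23])
    synced_flags = packed_tensor[-1:]
    global_with_prefill = bool(synced_flags[0])    # True: some rank prefills
    max_tokens_across_dp = int(num_tokens_across_dp.max())  # pad target: 23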