[main] remove dbo code (#3712)
### What this PR does / why we need it?
Remove the dbo (dual-batch overlap) code.
vLLM now supports dbo upstream via PR:
https://github.com/vllm-project/vllm/pull/23693.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -1,52 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
# enable dual-batch overlap for vllm ascend
|
||||
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"
|
||||
|
||||
# Sample prompts.
|
||||
prompts = ["The president of the United States is"] * 41
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic",
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=2,
|
||||
max_model_len=4096,
|
||||
trust_remote_code=True,
|
||||
enable_expert_parallel=True,
|
||||
additional_config={
|
||||
"torchair_graph_config": {
|
||||
"enabled": False
|
||||
},
|
||||
"ascend_scheduler_config": {
|
||||
"enabled": True
|
||||
},
|
||||
})
|
||||
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput
|
||||
# objects that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
# Add a buffer to wait for profiler in the background process
|
||||
# (in case MP is on) to finish writing profiling output.
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -623,11 +623,8 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
|
||||
|
||||
@patch('vllm_ascend.attention.mla_v1.get_forward_context')
|
||||
@patch("torch.npu.stream")
|
||||
@patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_forward_decode(self, mock_npu_fused_infer_attention_score,
|
||||
mock_get_multistream_comm_context, mock_npu_stream,
|
||||
mock_get_forward_context):
|
||||
B = 2
|
||||
N = self.impl.num_kv_heads
|
||||
@@ -651,8 +648,6 @@ class TestAscendMLAImpl(TestBase):
|
||||
mock_npu_fused_infer_attention_score.return_value = [
|
||||
torch.randn(B, N, self.impl.kv_lora_rank), None
|
||||
]
|
||||
mock_get_multistream_comm_context.return_value = None
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
|
||||
attn_metadata)
|
||||
@@ -660,18 +655,3 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(result.shape[0], B)
|
||||
self.assertEqual(result.shape[1], N)
|
||||
self.assertEqual(result.shape[2], HD)
|
||||
|
||||
self.impl.enable_kv_nz = False
|
||||
attn_metadata.attn_state = None
|
||||
mock_return_value = MagicMock()
|
||||
mock_get_multistream_comm_context.return_value = mock_return_value
|
||||
mock_return_value.before_comm_event = MagicMock()
|
||||
mock_return_value.comm_stream = MagicMock()
|
||||
mock_npu_stream.return_value = MagicMock()
|
||||
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
|
||||
attn_metadata)
|
||||
|
||||
self.assertEqual(result.shape[0], B)
|
||||
self.assertEqual(result.shape[1], N)
|
||||
self.assertEqual(result.shape[2], HD)
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.multistream.base import (MSAttentionMetadataSplitConfig,
|
||||
MSEventKey)
|
||||
|
||||
|
||||
class Testbase(TestBase):
|
||||
|
||||
def test_ms_event_key(self):
|
||||
self.assertEqual(MSEventKey.ATTN_COM_FINISH.value, 0)
|
||||
self.assertEqual(MSEventKey.ATTN_AR_FINISH.value, 1)
|
||||
self.assertEqual(MSEventKey.FFN_COM_FINISH.value, 2)
|
||||
self.assertEqual(MSEventKey.FFN_AR_FINISH.value, 3)
|
||||
self.assertEqual(MSEventKey.MOE_BEFORE_COMM.value, 4)
|
||||
self.assertEqual(MSEventKey.MOE_AFTER_COMM.value, 5)
|
||||
self.assertEqual(MSEventKey.MOE_SE_COMM_FINISH.value, 6)
|
||||
self.assertEqual(MSEventKey.MOE_SE_COMP_FINISH.value, 7)
|
||||
self.assertEqual(MSEventKey.MOE_GATE_FINISH.value, 8)
|
||||
|
||||
def test_ms_attention_metadata_split_config_default(self):
|
||||
config = MSAttentionMetadataSplitConfig()
|
||||
self.assertEqual(config.num_micro_batches, 2)
|
||||
self.assertEqual(config.min_total_tokens_to_split, 256)
|
||||
self.assertEqual(config.min_prefill_tokens_to_split, 64)
|
||||
|
||||
def test_ms_attention_metadata_split_config_custom(self):
|
||||
config = MSAttentionMetadataSplitConfig(
|
||||
num_micro_batches=4,
|
||||
min_total_tokens_to_split=512,
|
||||
min_prefill_tokens_to_split=128)
|
||||
self.assertEqual(config.num_micro_batches, 4)
|
||||
self.assertEqual(config.min_total_tokens_to_split, 512)
|
||||
self.assertEqual(config.min_prefill_tokens_to_split, 128)
|
||||
@@ -1,47 +0,0 @@
|
||||
import pytest
|
||||
from pytest_mock import MockFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.multistream.decorator import set_multistream_support
|
||||
|
||||
|
||||
class Context:
|
||||
|
||||
def __init__(self, attn_metadata=None):
|
||||
self.attn_metadata = attn_metadata
|
||||
|
||||
|
||||
class TestDecorator(PytestBase):
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'layer_context, microbatch_context, expected_metadata', [
|
||||
((-1, None, None), -1, {
|
||||
"original": True
|
||||
}),
|
||||
((-1, None, None), 0, {
|
||||
"original": True
|
||||
}),
|
||||
((0, None, None), -1, {
|
||||
"original": True
|
||||
}),
|
||||
((0, None, [{
|
||||
"new": True
|
||||
}]), 0, {
|
||||
"new": True
|
||||
}),
|
||||
])
|
||||
def test_decorator(self, mocker: MockFixture, layer_context,
|
||||
microbatch_context, expected_metadata):
|
||||
|
||||
def context_func():
|
||||
return Context(attn_metadata={"original": True})
|
||||
|
||||
mocker.patch(
|
||||
'vllm_ascend.multistream.decorator.get_multistream_layer_context',
|
||||
return_value=layer_context)
|
||||
mocker.patch(
|
||||
'vllm_ascend.multistream.decorator.get_multistream_microbatch_context',
|
||||
return_value=microbatch_context)
|
||||
|
||||
context = set_multistream_support()(context_func)()
|
||||
assert context.attn_metadata == expected_metadata
|
||||
@@ -1,198 +0,0 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.multistream.base import MSEventKey
|
||||
from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
|
||||
MultiStreamPreTransformerLayer)
|
||||
from vllm_ascend.multistream.metadata import MultiStreamMetadata
|
||||
|
||||
|
||||
# === fixture: mock tensor input ===
|
||||
@pytest.fixture
|
||||
def input_tensors():
|
||||
return [torch.randn(2, 128), torch.randn(2, 128)]
|
||||
|
||||
|
||||
# === mock get_forward_context ===
|
||||
class DummyContext:
|
||||
|
||||
def __init__(self, attn_metadata):
|
||||
self.attn_metadata = attn_metadata
|
||||
|
||||
|
||||
class TestMultiStreamPreTransformerLayer(PytestBase):
|
||||
|
||||
# === test when multistream_metadata is None ===
|
||||
@patch("vllm_ascend.multistream.layers.get_forward_context")
|
||||
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
|
||||
def test_forward_no_multistream_metadata(self, mock_set_ctx, mock_get_ctx,
|
||||
input_tensors):
|
||||
mock_get_ctx.return_value = DummyContext(attn_metadata="dummy_meta")
|
||||
layer = MultiStreamPreTransformerLayer(multistream_metadata=None)
|
||||
attn_out, input_out = layer.forward(input_tensors)
|
||||
|
||||
assert attn_out == "dummy_meta"
|
||||
assert input_out == input_tensors
|
||||
mock_set_ctx.assert_called_once_with(-1, None, None)
|
||||
|
||||
# === test when attn_metadata is None ===
|
||||
@patch("vllm_ascend.multistream.layers.get_forward_context")
|
||||
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
|
||||
def test_forward_no_attn_metadata(self, mock_set_ctx, mock_get_ctx,
|
||||
input_tensors):
|
||||
mock_get_ctx.return_value = DummyContext(attn_metadata=None)
|
||||
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
|
||||
layer = MultiStreamPreTransformerLayer(
|
||||
multistream_metadata=dummy_metadata)
|
||||
|
||||
attn_out, input_out = layer.forward(input_tensors)
|
||||
|
||||
assert attn_out is None
|
||||
assert input_out == input_tensors
|
||||
mock_set_ctx.assert_called_once_with(-1, None, None)
|
||||
|
||||
# === test when do_ms=False (no split needed) ===
|
||||
@patch("vllm_ascend.multistream.layers.get_forward_context")
|
||||
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
|
||||
def test_forward_no_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
|
||||
dummy_attn = "original_attn"
|
||||
mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)
|
||||
|
||||
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
|
||||
dummy_metadata.split_micro_batch.return_value = (False, "same_attn",
|
||||
input_tensors, None)
|
||||
|
||||
layer = MultiStreamPreTransformerLayer(
|
||||
multistream_metadata=dummy_metadata)
|
||||
|
||||
attn_out, input_out = layer.forward(input_tensors)
|
||||
|
||||
assert attn_out == "same_attn"
|
||||
assert input_out == input_tensors
|
||||
mock_set_ctx.assert_called_once_with(-1, None, None)
|
||||
|
||||
# === test when do_ms=True (split occurred) ===
|
||||
@patch("vllm_ascend.multistream.layers.get_forward_context")
|
||||
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
|
||||
def test_forward_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
|
||||
dummy_attn = "original_attn"
|
||||
mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)
|
||||
|
||||
split_inputs = [[t[:1], t[1:]] for t in input_tensors]
|
||||
|
||||
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
|
||||
dummy_metadata.start_layer = 2
|
||||
dummy_metadata.split_micro_batch.return_value = (True,
|
||||
["attn1", "attn2"],
|
||||
split_inputs, None)
|
||||
|
||||
layer = MultiStreamPreTransformerLayer(
|
||||
multistream_metadata=dummy_metadata)
|
||||
|
||||
attn_out, input_out = layer.forward(input_tensors)
|
||||
|
||||
assert attn_out == ["attn1", "attn2"]
|
||||
assert input_out == split_inputs
|
||||
mock_set_ctx.assert_called_once_with(2, dummy_metadata,
|
||||
["attn1", "attn2"])
|
||||
|
||||
|
||||
class TestMultiStreamPostTransformerLayer(PytestBase):
|
||||
|
||||
def test_post_forward_metadata_none(self, input_tensors):
|
||||
layer = MultiStreamPostTransformerLayer(multistream_metadata=None)
|
||||
output = layer.forward(input_tensors)
|
||||
assert output == input_tensors
|
||||
|
||||
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
|
||||
dummy_metadata.ms_config = None
|
||||
layer = MultiStreamPostTransformerLayer(
|
||||
multistream_metadata=dummy_metadata)
|
||||
output = layer.forward(input_tensors)
|
||||
assert output == input_tensors
|
||||
|
||||
@patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
|
||||
@patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
|
||||
def test_post_forward_normal_flow(self, mock_reset_ctx, mock_get_ctx,
|
||||
input_tensors):
|
||||
A_instance_of_MultiStreamMetadata = MultiStreamMetadata(
|
||||
calculate_stream=MagicMock(),
|
||||
communicate_stream=MagicMock(),
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
event_keys=[],
|
||||
multistream_config=None,
|
||||
)
|
||||
dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata)
|
||||
dummy_metadata.ms_config.num_micro_batches = 4
|
||||
dummy_metadata.end_layer = 10
|
||||
|
||||
mock_get_ctx.return_value = (
|
||||
5, # layer_index
|
||||
dummy_metadata, # ms_metadata
|
||||
"dummy_attn_metadata" # ms_attn_metadata
|
||||
)
|
||||
|
||||
dummy_metadata.merge_micro_batches.return_value = "merged_result"
|
||||
|
||||
layer = MultiStreamPostTransformerLayer(
|
||||
multistream_metadata=dummy_metadata)
|
||||
output = layer.forward(input_tensors)
|
||||
|
||||
# check wait_event
|
||||
dummy_metadata.try_wait_event.assert_called_once_with(
|
||||
9, # end_layer - 1
|
||||
3, # num_micro_batches - 1
|
||||
MSEventKey.FFN_AR_FINISH)
|
||||
mock_reset_ctx.assert_called_once()
|
||||
assert output == "merged_result"
|
||||
|
||||
@patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
|
||||
@patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
|
||||
def test_post_forward_with_custom_wait_layer(self, mock_reset_ctx,
|
||||
mock_get_ctx, input_tensors):
|
||||
A_instance_of_MultiStreamMetadata = MultiStreamMetadata(
|
||||
calculate_stream=MagicMock(),
|
||||
communicate_stream=MagicMock(),
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
event_keys=[],
|
||||
multistream_config=None,
|
||||
)
|
||||
dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata)
|
||||
dummy_metadata.ms_config.num_micro_batches = 4
|
||||
dummy_metadata.end_layer = 10
|
||||
|
||||
mock_get_ctx.return_value = (
|
||||
3, # layer_index
|
||||
dummy_metadata,
|
||||
"dummy_attn_metadata")
|
||||
|
||||
dummy_metadata.merge_micro_batches.return_value = "merged_result"
|
||||
|
||||
layer = MultiStreamPostTransformerLayer(
|
||||
multistream_metadata=dummy_metadata)
|
||||
output = layer.forward(input_tensors, wait_layer_index=7)
|
||||
|
||||
dummy_metadata.try_wait_event.assert_called_once_with(
|
||||
7, 3, MSEventKey.FFN_AR_FINISH)
|
||||
mock_reset_ctx.assert_called_once()
|
||||
assert output == "merged_result"
|
||||
@@ -1,246 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.multistream.base import MSEventKey
|
||||
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
|
||||
MultiStreamMetadata,
|
||||
MultiStreamStepMetadata,
|
||||
split_micro_batches_tensors)
|
||||
|
||||
|
||||
class TestMetaData(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.test_tensors_list = [torch.randn(100, 1024) for i in range(3)]
|
||||
self.test_tensors = torch.randn(100, 1024)
|
||||
self.test_tensors_dict = {
|
||||
'query': torch.randn(100, 1024),
|
||||
'key': torch.randn(100, 1024),
|
||||
'value': torch.randn(100, 1024)
|
||||
}
|
||||
self.split_index = 50
|
||||
|
||||
mock_stream = MagicMock(spec=torch.npu.Stream)
|
||||
event_keys = [MagicMock(spec=MSEventKey)]
|
||||
multistream_config = MagicMock(spec=MultiStreamConfig)
|
||||
|
||||
self.metadata = MultiStreamMetadata(
|
||||
calculate_stream=mock_stream,
|
||||
communicate_stream=mock_stream,
|
||||
start_layer=1,
|
||||
end_layer=3,
|
||||
event_keys=event_keys,
|
||||
multistream_config=multistream_config)
|
||||
|
||||
def test_split_micro_batches_tensors(self):
|
||||
test_tensors_list_res = split_micro_batches_tensors(
|
||||
self.test_tensors_list, self.split_index)
|
||||
test_tensors_res = split_micro_batches_tensors(self.test_tensors,
|
||||
self.split_index)
|
||||
keys = ['query', 'key', 'value']
|
||||
test_tensors_dict_res = split_micro_batches_tensors(
|
||||
self.test_tensors_dict, self.split_index, keys)
|
||||
for i in range(3):
|
||||
self.assertEqual(len(test_tensors_list_res[i][0]),
|
||||
self.split_index)
|
||||
|
||||
self.assertEqual(
|
||||
len(test_tensors_list_res[i][0]) +
|
||||
len(test_tensors_list_res[i][1]), 100)
|
||||
|
||||
self.assertEqual(len(test_tensors_res[0]), self.split_index)
|
||||
self.assertEqual(
|
||||
len(test_tensors_res[0]) + len(test_tensors_res[1]), 100)
|
||||
|
||||
for key in keys:
|
||||
self.assertEqual(len(test_tensors_dict_res[0][key]),
|
||||
self.split_index)
|
||||
self.assertEqual(
|
||||
len(test_tensors_dict_res[0][key]) +
|
||||
len(test_tensors_dict_res[1][key]), 100)
|
||||
|
||||
def test_default_init_multistream_step_metadata(self):
|
||||
metadata = MultiStreamStepMetadata()
|
||||
self.assertIsNone(metadata.comm_stream)
|
||||
self.assertIsNone(metadata.before_comm_event)
|
||||
self.assertIsNone(metadata.after_comm_event)
|
||||
|
||||
def test_custom_init_multistream_step_metadata(self):
|
||||
mockStream = MagicMock(spec=torch.npu.Stream)
|
||||
mockEvent1 = MagicMock(spec=torch.npu.Event)
|
||||
mockEvent2 = MagicMock(spec=torch.npu.Event)
|
||||
|
||||
metadata = MultiStreamStepMetadata(mockStream, mockEvent1, mockEvent2)
|
||||
self.assertEqual(metadata.comm_stream, mockStream)
|
||||
self.assertEqual(metadata.before_comm_event, mockEvent1)
|
||||
self.assertEqual(metadata.after_comm_event, mockEvent2)
|
||||
|
||||
def test_default_init_multistream_config(self):
|
||||
config = MultiStreamConfig()
|
||||
self.assertEqual(config.min_total_tokens_to_split, 256)
|
||||
self.assertEqual(config.min_prefill_tokens_to_split, 64)
|
||||
self.assertEqual(config.num_micro_batches, 2)
|
||||
self.assertEqual(config.imbalance_ratio, 0.1)
|
||||
|
||||
def test_custom_init_multistream_config(self):
|
||||
config = MultiStreamConfig(512, 128, 1, 0.2)
|
||||
self.assertEqual(config.min_total_tokens_to_split, 512)
|
||||
self.assertEqual(config.min_prefill_tokens_to_split, 128)
|
||||
self.assertEqual(config.num_micro_batches, 1)
|
||||
self.assertEqual(config.imbalance_ratio, 0.2)
|
||||
|
||||
def test_init_multistream_metadata(self):
|
||||
mock_stream = MagicMock(spec=torch.npu.Stream)
|
||||
|
||||
event_keys = [MagicMock()]
|
||||
multistream_config = MagicMock(spec=MultiStreamConfig)
|
||||
|
||||
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
|
||||
communicate_stream=mock_stream,
|
||||
start_layer=1,
|
||||
end_layer=3,
|
||||
event_keys=event_keys,
|
||||
multistream_config=multistream_config)
|
||||
|
||||
self.assertEqual(metadata.calculate_stream, mock_stream)
|
||||
self.assertEqual(metadata.communicate_stream, mock_stream)
|
||||
self.assertEqual(metadata.start_layer, 1)
|
||||
self.assertEqual(metadata.end_layer, 3)
|
||||
self.assertEqual(metadata.ms_config, multistream_config)
|
||||
self.assertTrue(metadata.causal_lm)
|
||||
|
||||
def test_build_events(self):
|
||||
mock_stream = MagicMock(spec=torch.npu.Stream)
|
||||
mock_event = MagicMock(spec=torch.npu.Event)
|
||||
with patch('torch.npu.Event', return_value=mock_event):
|
||||
event_keys = [MagicMock(spec=MSEventKey)]
|
||||
multistream_config = MultiStreamConfig(
|
||||
num_micro_batches=2,
|
||||
min_total_tokens_to_split=256,
|
||||
min_prefill_tokens_to_split=64)
|
||||
|
||||
metadata = MultiStreamMetadata(
|
||||
calculate_stream=mock_stream,
|
||||
communicate_stream=mock_stream,
|
||||
start_layer=1,
|
||||
end_layer=3,
|
||||
event_keys=event_keys,
|
||||
multistream_config=multistream_config)
|
||||
|
||||
expected_events = {
|
||||
0: {
|
||||
0: {
|
||||
event_keys[0]: mock_event
|
||||
},
|
||||
1: {
|
||||
event_keys[0]: mock_event
|
||||
}
|
||||
},
|
||||
1: {
|
||||
0: {
|
||||
event_keys[0]: mock_event
|
||||
},
|
||||
1: {
|
||||
event_keys[0]: mock_event
|
||||
}
|
||||
},
|
||||
2: {
|
||||
0: {
|
||||
event_keys[0]: mock_event
|
||||
},
|
||||
1: {
|
||||
event_keys[0]: mock_event
|
||||
}
|
||||
}
|
||||
}
|
||||
self.assertEqual(metadata.ms_events, expected_events)
|
||||
|
||||
def test_build_ms_split_config(self):
|
||||
mock_stream = MagicMock(spec=torch.npu.Stream)
|
||||
event_keys = [MagicMock(spec=MSEventKey)]
|
||||
multistream_config = MagicMock(spec=MultiStreamConfig)
|
||||
multistream_config.num_micro_batches = 2
|
||||
multistream_config.min_total_tokens_to_split = 256
|
||||
multistream_config.min_prefill_tokens_to_split = 64
|
||||
|
||||
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
|
||||
communicate_stream=mock_stream,
|
||||
start_layer=1,
|
||||
end_layer=3,
|
||||
event_keys=event_keys,
|
||||
multistream_config=multistream_config)
|
||||
|
||||
self.assertIsNotNone(metadata.ms_split_config)
|
||||
self.assertEqual(metadata.ms_split_config.num_micro_batches,
|
||||
multistream_config.num_micro_batches)
|
||||
self.assertEqual(metadata.ms_split_config.min_total_tokens_to_split,
|
||||
multistream_config.min_total_tokens_to_split)
|
||||
self.assertEqual(metadata.ms_split_config.min_prefill_tokens_to_split,
|
||||
multistream_config.min_prefill_tokens_to_split)
|
||||
|
||||
def test_try_wait_event(self):
|
||||
mock_stream = MagicMock(spec=torch.npu.Stream)
|
||||
mock_event = MagicMock(spec=torch.npu.Event)
|
||||
event_keys = [MagicMock(spec=MSEventKey)]
|
||||
multistream_config = MagicMock(spec=MultiStreamConfig)
|
||||
with patch('torch.npu.Event', return_value=mock_event):
|
||||
metadata = MultiStreamMetadata(
|
||||
calculate_stream=mock_stream,
|
||||
communicate_stream=mock_stream,
|
||||
start_layer=1,
|
||||
end_layer=3,
|
||||
event_keys=event_keys,
|
||||
multistream_config=multistream_config)
|
||||
|
||||
metadata.try_wait_event(layer_index=1,
|
||||
micro_batch_index=0,
|
||||
event_key=event_keys[0])
|
||||
mock_event.wait.assert_called_once()
|
||||
|
||||
def test_try_record_event(self):
|
||||
mock_stream = MagicMock(spec=torch.npu.Stream)
|
||||
mock_event = MagicMock(spec=torch.npu.Event)
|
||||
event_keys = [MagicMock(spec=MSEventKey)]
|
||||
multistream_config = MagicMock(spec=MultiStreamConfig)
|
||||
with patch('torch.npu.Event', return_value=mock_event):
|
||||
metadata = MultiStreamMetadata(
|
||||
calculate_stream=mock_stream,
|
||||
communicate_stream=mock_stream,
|
||||
start_layer=1,
|
||||
end_layer=3,
|
||||
event_keys=event_keys,
|
||||
multistream_config=multistream_config)
|
||||
|
||||
metadata.try_record_event(layer_index=1,
|
||||
micro_batch_index=0,
|
||||
event_key=event_keys[0])
|
||||
mock_event.record.assert_called_once()
|
||||
|
||||
def test_merge_batches_none_input(self):
|
||||
input_tensors = None
|
||||
result = self.metadata.merge_micro_batches(input_tensors)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_merge_batches_single_tensor_input(self):
|
||||
input_tensors = [torch.tensor([1, 2, 3])]
|
||||
result = self.metadata.merge_micro_batches(input_tensors)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3])))
|
||||
|
||||
def test_merge_batches_list_of_tensors_input(self):
|
||||
input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])]
|
||||
result = self.metadata.merge_micro_batches(input_tensors)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result, input_tensors)
|
||||
|
||||
def test_merge_batches_nested_list_input(self):
|
||||
input_tensors = [[torch.tensor([1, 2]),
|
||||
torch.tensor([3, 4])],
|
||||
[torch.tensor([5, 6]),
|
||||
torch.tensor([7, 8])]]
|
||||
result = self.metadata.merge_micro_batches(input_tensors)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3, 4])))
|
||||
self.assertTrue(torch.equal(result[1], torch.tensor([5, 6, 7, 8])))
|
||||
@@ -1,147 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.ms_split import (compute_split_seq_index,
|
||||
model_input_split_v1_mla_attn,
|
||||
split_attn_int_type,
|
||||
split_attn_tensor_type)
|
||||
|
||||
|
||||
class TestMsSplit(TestBase):
|
||||
|
||||
def test_decode_only(self):
|
||||
result = compute_split_seq_index(
|
||||
query_lens=None,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
num_tokens=10)
|
||||
self.assertEqual(result, [5, 5])
|
||||
|
||||
def test_perfect_balance(self):
|
||||
query_lens = [2, 3, 5]
|
||||
result = compute_split_seq_index(
|
||||
query_lens=query_lens,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_tokens=10)
|
||||
self.assertEqual(result, [5, 2])
|
||||
|
||||
def test_imbalance(self):
|
||||
query_lens = [1, 2, 3, 4]
|
||||
result = compute_split_seq_index(
|
||||
query_lens=query_lens,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_tokens=10)
|
||||
self.assertEqual(result, [0, 0])
|
||||
|
||||
def test_query_lens_none(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
compute_split_seq_index(
|
||||
query_lens=None,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_tokens=10)
|
||||
|
||||
def test_empty_query_lens(self):
|
||||
query_lens: list[int] = []
|
||||
result = compute_split_seq_index(
|
||||
query_lens=query_lens,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_tokens=10)
|
||||
self.assertEqual(result, [0, 0])
|
||||
|
||||
def test_single_query_len(self):
|
||||
query_lens = [10]
|
||||
result = compute_split_seq_index(
|
||||
query_lens=query_lens,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_tokens=10)
|
||||
self.assertEqual(result, [0, 0])
|
||||
|
||||
def test_split_attn_tensor_type_middle(self):
|
||||
input_tensor = torch.tensor([1, 2, 3, 4, 5])
|
||||
index = 3
|
||||
expected_result = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
|
||||
result = split_attn_tensor_type(input_tensor, index)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertTrue(torch.equal(result[0], expected_result[0]))
|
||||
self.assertTrue(torch.equal(result[1], expected_result[1]))
|
||||
|
||||
def test_split_attn_tensor_type_start(self):
|
||||
input_tensor = torch.tensor([1, 2, 3, 4, 5])
|
||||
index = 0
|
||||
expected_result = [torch.tensor([]), torch.tensor([1, 2, 3, 4, 5])]
|
||||
result = split_attn_tensor_type(input_tensor, index)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertTrue(torch.equal(result[0], expected_result[0]))
|
||||
self.assertTrue(torch.equal(result[1], expected_result[1]))
|
||||
|
||||
def test_split_attn_tensor_type_end(self):
|
||||
input_tensor = torch.tensor([1, 2, 3, 4, 5])
|
||||
index = 5
|
||||
expected_result = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([])]
|
||||
result = split_attn_tensor_type(input_tensor, index)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertTrue(torch.equal(result[0], expected_result[0]))
|
||||
self.assertTrue(torch.equal(result[1], expected_result[1]))
|
||||
|
||||
def test_split_attn_tensor_type_empty_tensor(self):
|
||||
input_tensor = torch.tensor([])
|
||||
index = 0
|
||||
expected_result = [torch.tensor([]), torch.tensor([])]
|
||||
result = split_attn_tensor_type(input_tensor, index)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertTrue(torch.equal(result[0], expected_result[0]))
|
||||
self.assertTrue(torch.equal(result[1], expected_result[1]))
|
||||
|
||||
def test_split_attn_int_type_index_greater_than_var(self):
|
||||
var = 5
|
||||
index = 10
|
||||
expected_result = [5, 0]
|
||||
result = split_attn_int_type(var, index)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_split_attn_int_type_index_equal_to_var(self):
|
||||
var = 5
|
||||
index = 5
|
||||
expected_result = [5, 0]
|
||||
result = split_attn_int_type(var, index)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_split_attn_int_type_index_less_than_var(self):
|
||||
var = 10
|
||||
index = 5
|
||||
expected_result = [5, 5]
|
||||
result = split_attn_int_type(var, index)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_split_attn_int_type_index_zero(self):
|
||||
var = 10
|
||||
index = 0
|
||||
expected_result = [0, 10]
|
||||
result = split_attn_int_type(var, index)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_split_attn_int_type_var_zero(self):
|
||||
var = 0
|
||||
index = 5
|
||||
expected_result = [0, 0]
|
||||
result = split_attn_int_type(var, index)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_split_attn_int_type_both_zero(self):
|
||||
var = 0
|
||||
index = 0
|
||||
expected_result = [0, 0]
|
||||
result = split_attn_int_type(var, index)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_split_v1_mla_attn_input_none(self):
|
||||
attn_metadata = None
|
||||
ascendMLAPrefillMetadata = MagicMock()
|
||||
ms_split_config = MSAttentionMetadataSplitConfig(num_micro_batches=1)
|
||||
result = model_input_split_v1_mla_attn(attn_metadata,
|
||||
ascendMLAPrefillMetadata,
|
||||
ms_split_config)
|
||||
self.assertEqual(result, [None])
|
||||
@@ -210,9 +210,6 @@ class AscendMetadata:
|
||||
# (num_tokens,)
|
||||
slot_mapping: torch.Tensor = None
|
||||
|
||||
# *************************** Other Properties *************************** #
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
prefill: Optional[AscendMetadataForPrefill] = None
|
||||
|
||||
decode_meta: Optional[AscendMetadataForDecode] = None
|
||||
@@ -371,7 +368,6 @@ class AscendAttentionMetadataBuilder:
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
num_prefills=num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
prefill=prefill_metadata,
|
||||
|
||||
@@ -36,9 +36,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
trans_rope_weight, transdata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
@@ -184,7 +181,6 @@ class AscendMLAMetadata:
|
||||
|
||||
decode: Optional[AscendMLADecodeMetadata] = None
|
||||
prefill: Optional[AscendMLAPrefillMetadata] = None
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
@@ -195,17 +191,6 @@ class AscendMLAMetadata:
|
||||
# f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
# f"received {self.head_dim}.")
|
||||
|
||||
def split_metadata_for_multistream(
|
||||
self,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> list["AscendMLAMetadata"]:
|
||||
"""Split metadata for multi-stream with AscendMLAMetadata"""
|
||||
return model_input_split_v1_mla_attn(
|
||||
ms_split_config=ms_split_config,
|
||||
attn_metadata=self,
|
||||
_metadata_cls=AscendMLAMetadata,
|
||||
)
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
|
||||
@@ -538,7 +523,6 @@ class AscendMLAMetadataBuilder:
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
|
||||
def build_for_graph_capture(
|
||||
@@ -1158,13 +1142,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
else:
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
q_nope, k_nope, k_nope, **common_kwargs)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
return self._v_up_proj(attn_output)
|
||||
else:
|
||||
current_ms_metadata.before_comm_event.record()
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
current_ms_metadata.before_comm_event.wait()
|
||||
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
|
||||
@@ -1423,12 +1401,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
|
||||
decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
|
||||
kv_cache[0].shape[1], attn_metadata)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is not None:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
o_proj_input[:num_decode_tokens] = output_decode
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
else:
|
||||
|
||||
o_proj_input[:num_decode_tokens] = output_decode
|
||||
|
||||
if prefill_preprocess_res is not None:
|
||||
@@ -1445,18 +1418,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
|
||||
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
|
||||
prefill_preprocess_res.value, kv_cache, attn_metadata)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is not None:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
o_proj_input[num_decode_tokens:] = output_prefill
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
else:
|
||||
o_proj_input[
|
||||
num_decode_tokens:num_actual_tokens] = output_prefill
|
||||
|
||||
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
|
||||
# O proj
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
|
||||
if current_ms_metadata is None:
|
||||
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
||||
dependency=o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
@@ -1465,16 +1430,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
output[...] = self.o_proj(o_proj_input,
|
||||
is_prefill=prefill_preprocess_res
|
||||
is not None)[0]
|
||||
else:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
||||
dependency=o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=self.enable_prefetch)
|
||||
output[...] = self.o_proj(o_proj_input,
|
||||
is_prefill=prefill_preprocess_res
|
||||
is not None)[0]
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
|
||||
del o_proj_input
|
||||
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
@@ -1719,18 +1675,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
|
||||
seq_mask_pcp[:, i])
|
||||
attn_output = attn_out_g
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
return self._v_up_proj(attn_output)
|
||||
else:
|
||||
current_ms_metadata.before_comm_event.record()
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
current_ms_metadata.before_comm_event.wait()
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
|
||||
# TODO use update op to replace this
|
||||
|
||||
def _update_out_and_lse(
|
||||
self,
|
||||
out: torch.Tensor,
|
||||
|
||||
@@ -17,11 +17,8 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -138,7 +135,6 @@ class AscendSFAMetadata:
|
||||
|
||||
decode: Optional[AscendSFADecodeMetadata] = None
|
||||
prefill: Optional[AscendSFAPrefillMetadata] = None
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
@@ -149,17 +145,6 @@ class AscendSFAMetadata:
|
||||
# f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
# f"received {self.head_dim}.")
|
||||
|
||||
def split_metadata_for_multistream(
|
||||
self,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> list["AscendSFAMetadata"]:
|
||||
"""Split metadata for multi-stream with AscendSFAMetadata"""
|
||||
return model_input_split_v1_mla_attn(
|
||||
ms_split_config=ms_split_config,
|
||||
attn_metadata=self,
|
||||
_metadata_cls=AscendMLAMetadata,
|
||||
)
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
|
||||
@@ -434,7 +419,6 @@ class AscendSFAMetadataBuilder:
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -91,8 +91,6 @@ class AscendCommonAttentionMetadata:
|
||||
|
||||
attn_state: Any = None
|
||||
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
is_only_prefill: bool = False
|
||||
|
||||
graph_pad_size: int = -1
|
||||
|
||||
@@ -82,9 +82,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
|
||||
),
|
||||
# Whether to enable DBO feature for deepseek model.
|
||||
"VLLM_ASCEND_ENABLE_DBO":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
|
||||
# Whether to enable the model execute time observe profile. Disable it when
|
||||
# running vllm ascend in production environment.
|
||||
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MSEventKey(Enum):
|
||||
ATTN_COM_FINISH = 0
|
||||
ATTN_AR_FINISH = 1
|
||||
FFN_COM_FINISH = 2
|
||||
FFN_AR_FINISH = 3
|
||||
# events for MOE dispatch and combine
|
||||
MOE_BEFORE_COMM = 4
|
||||
MOE_AFTER_COMM = 5
|
||||
# events for shared expert
|
||||
MOE_SE_COMM_FINISH = 6
|
||||
MOE_SE_COMP_FINISH = 7
|
||||
MOE_GATE_FINISH = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class MSAttentionMetadataSplitConfig:
|
||||
"""
|
||||
micro batch split config for split attention metadata
|
||||
"""
|
||||
# micro batch num
|
||||
num_micro_batches: int = 2
|
||||
# split micro batches only when total tokens >= min_total_tokens_to_split
|
||||
min_total_tokens_to_split: int = 256
|
||||
# split micro batches only when prefill tokens >= min_prefill_tokens_to_split
|
||||
min_prefill_tokens_to_split: int = 64
|
||||
@@ -1,67 +0,0 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
_ms_comm_context: Any = None
|
||||
_cur_micro_batch_num: int = -1
|
||||
_ms_layer_index_context: int = -1
|
||||
_ms_metadata_context: Any = None
|
||||
_ms_attn_metadata_context: Any = None
|
||||
|
||||
|
||||
def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
|
||||
attn_metadata: Any):
|
||||
"""
|
||||
set multistream layer context before transformer layers
|
||||
"""
|
||||
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||
_ms_layer_index_context = start_layer
|
||||
_ms_metadata_context = ms_metadata
|
||||
_ms_attn_metadata_context = attn_metadata
|
||||
|
||||
|
||||
def reset_multistream_layer_context():
|
||||
"""
|
||||
reset multistream layer context
|
||||
"""
|
||||
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||
_ms_layer_index_context = -1
|
||||
_ms_metadata_context = None
|
||||
_ms_attn_metadata_context = None
|
||||
|
||||
|
||||
def get_multistream_layer_context():
|
||||
"""
|
||||
get multistream layer context
|
||||
"""
|
||||
return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||
|
||||
|
||||
def advance_step_multistream_layer_context():
|
||||
"""
|
||||
advance multistream layer index context
|
||||
"""
|
||||
global _ms_layer_index_context
|
||||
_ms_layer_index_context += 1
|
||||
|
||||
|
||||
def get_multistream_comm_context() -> Any:
|
||||
"""Get the current comm forward context."""
|
||||
return _ms_comm_context
|
||||
|
||||
|
||||
def get_multistream_microbatch_context() -> int:
|
||||
return _cur_micro_batch_num
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_multistream_context(context: Any, micro_batch_num: int):
|
||||
"""A context manager that stores the current comm forward context,
|
||||
can be attention metadata, etc."""
|
||||
global _ms_comm_context, _cur_micro_batch_num
|
||||
_ms_comm_context = context
|
||||
_cur_micro_batch_num = micro_batch_num
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_ms_comm_context = None
|
||||
_cur_micro_batch_num = -1
|
||||
@@ -1,22 +0,0 @@
|
||||
from .context import (get_multistream_layer_context,
|
||||
get_multistream_microbatch_context)
|
||||
|
||||
|
||||
# vllm v1 use get_forward_context to get the attn_metadata,
|
||||
# we can use this decorator to update the attn metadata
|
||||
def set_multistream_support():
|
||||
|
||||
def decorator(func):
|
||||
|
||||
def wrapper():
|
||||
context = func()
|
||||
layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
|
||||
)
|
||||
micro_batch_num = get_multistream_microbatch_context()
|
||||
if layer_index != -1 and micro_batch_num != -1:
|
||||
context.attn_metadata = attn_metadata[micro_batch_num]
|
||||
return context
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -1,61 +0,0 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from .base import MSEventKey
|
||||
from .context import (get_multistream_layer_context,
|
||||
reset_multistream_layer_context,
|
||||
set_multistream_layer_context)
|
||||
from .metadata import MultiStreamMetadata
|
||||
|
||||
|
||||
class MultiStreamPreTransformerLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, multistream_metadata: MultiStreamMetadata):
|
||||
super().__init__()
|
||||
self.multistream_metadata = multistream_metadata
|
||||
|
||||
def forward(
|
||||
self,
|
||||
intput_tensors: List[torch.Tensor],
|
||||
):
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if self.multistream_metadata is None or attn_metadata is None:
|
||||
set_multistream_layer_context(-1, None, None)
|
||||
return attn_metadata, intput_tensors
|
||||
# TODO add attn_metadata management
|
||||
do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(
|
||||
attn_metadata, intput_tensors)
|
||||
if do_ms:
|
||||
set_multistream_layer_context(
|
||||
self.multistream_metadata.start_layer,
|
||||
self.multistream_metadata, attn_metadata)
|
||||
else:
|
||||
set_multistream_layer_context(-1, None, None)
|
||||
return attn_metadata, intput_tensors
|
||||
|
||||
|
||||
class MultiStreamPostTransformerLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, multistream_metadata: MultiStreamMetadata):
|
||||
super().__init__()
|
||||
self.multistream_metadata = multistream_metadata
|
||||
|
||||
def forward(self,
|
||||
input_tensors: Union[List[Tuple[torch.Tensor]],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]]],
|
||||
wait_layer_index: Optional[int] = None):
|
||||
if self.multistream_metadata is None or self.multistream_metadata.ms_config is None:
|
||||
return input_tensors
|
||||
layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context(
|
||||
)
|
||||
if layer_index >= 0:
|
||||
true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index
|
||||
self.multistream_metadata.try_wait_event(
|
||||
true_wait_layer,
|
||||
self.multistream_metadata.ms_config.num_micro_batches - 1,
|
||||
MSEventKey.FFN_AR_FINISH)
|
||||
reset_multistream_layer_context()
|
||||
return self.multistream_metadata.merge_micro_batches(input_tensors)
|
||||
@@ -1,182 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||
|
||||
from .base import MSAttentionMetadataSplitConfig, MSEventKey
|
||||
|
||||
|
||||
def split_micro_batches_tensors(input_tensors,
|
||||
split_index: int,
|
||||
keys: Optional[List[str]] = None):
|
||||
if isinstance(input_tensors, list):
|
||||
micro_batches = []
|
||||
for tensor in input_tensors:
|
||||
if tensor is None:
|
||||
micro_batches.append([None, None])
|
||||
else:
|
||||
micro_batches.append(
|
||||
[tensor[:split_index], tensor[split_index:]])
|
||||
return micro_batches
|
||||
elif isinstance(input_tensors, torch.Tensor):
|
||||
return [input_tensors[:split_index], input_tensors[split_index:]]
|
||||
elif input_tensors is None:
|
||||
return [None, None]
|
||||
elif isinstance(input_tensors, Dict):
|
||||
assert keys is not None
|
||||
micro_batches_pre = {}
|
||||
for key in keys:
|
||||
micro_batches_pre[key] = input_tensors[key][:split_index]
|
||||
micro_batches_post = {}
|
||||
for key in keys:
|
||||
micro_batches_post[key] = input_tensors[key][split_index:]
|
||||
return [micro_batches_pre, micro_batches_post]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiStreamStepMetadata:
|
||||
comm_stream: torch.npu.Stream = None
|
||||
before_comm_event: torch.npu.Event = None
|
||||
after_comm_event: torch.npu.Event = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiStreamConfig:
|
||||
"""Controls the behavior of multi-stream models."""
|
||||
min_total_tokens_to_split: int = 256
|
||||
min_prefill_tokens_to_split: int = 64
|
||||
num_micro_batches: int = 2
|
||||
imbalance_ratio: float = 0.1
|
||||
|
||||
|
||||
class MultiStreamMetadata:
|
||||
# direct stream
|
||||
calculate_stream = None
|
||||
# delay stream
|
||||
communicate_stream = None
|
||||
# events
|
||||
ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {}
|
||||
# multi-stream-flag
|
||||
enable_multi_stream: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calculate_stream: torch.npu.Stream,
|
||||
communicate_stream: torch.npu.Stream,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
event_keys: List[MSEventKey],
|
||||
multistream_config: Optional[MultiStreamConfig],
|
||||
causal_lm: bool = True,
|
||||
):
|
||||
self.calculate_stream = calculate_stream
|
||||
self.communicate_stream = communicate_stream
|
||||
self.start_layer = start_layer
|
||||
self.end_layer = end_layer
|
||||
self.ms_config = multistream_config
|
||||
self.causal_lm = causal_lm
|
||||
self._build_events(event_keys)
|
||||
self._build_ms_split_config()
|
||||
|
||||
def _build_events(self, event_keys):
|
||||
if self.ms_config is not None:
|
||||
for i in range(self.start_layer - 1, self.end_layer):
|
||||
self.ms_events[i] = {}
|
||||
for j in range(self.ms_config.num_micro_batches):
|
||||
self.ms_events[i][j] = {}
|
||||
for key in event_keys:
|
||||
self.ms_events[i][j][key] = torch.npu.Event()
|
||||
|
||||
def _build_ms_split_config(self):
|
||||
if self.ms_config is not None:
|
||||
self.ms_split_config = MSAttentionMetadataSplitConfig(
|
||||
num_micro_batches=self.ms_config.num_micro_batches,
|
||||
min_total_tokens_to_split=self.ms_config.
|
||||
min_total_tokens_to_split,
|
||||
min_prefill_tokens_to_split=self.ms_config.
|
||||
min_prefill_tokens_to_split,
|
||||
)
|
||||
|
||||
def try_wait_event(self, layer_index: int, micro_batch_index: int,
|
||||
event_key: MSEventKey):
|
||||
self.ms_events[layer_index][micro_batch_index][event_key].wait()
|
||||
|
||||
def try_record_event(self, layer_index: int, micro_batch_index: int,
|
||||
event_key: MSEventKey):
|
||||
self.ms_events[layer_index][micro_batch_index][event_key].record()
|
||||
|
||||
def split_micro_batch(
|
||||
self,
|
||||
attn_metadata: "AscendMLAMetadata",
|
||||
intput_tensors: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
intermediate_tensors_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[
|
||||
List[torch.Tensor], List[List[torch.Tensor]]], Union[
|
||||
IntermediateTensors, List[IntermediateTensors]]]:
|
||||
attn_metadata_list = attn_metadata.split_metadata_for_multistream(
|
||||
self.ms_split_config)
|
||||
if len(attn_metadata_list) == 1:
|
||||
return False, attn_metadata_list[
|
||||
0], intput_tensors, intermediate_tensors
|
||||
split_index = attn_metadata_list[0].slot_mapping.shape[0]
|
||||
input_tensors = split_micro_batches_tensors(intput_tensors,
|
||||
split_index)
|
||||
if intermediate_tensors is not None:
|
||||
inter_tensors_list = split_micro_batches_tensors(
|
||||
intermediate_tensors.tensors, split_index,
|
||||
intermediate_tensors_keys)
|
||||
intermediate_tensors = [
|
||||
IntermediateTensors(inter_tensors)
|
||||
for inter_tensors in inter_tensors_list
|
||||
]
|
||||
return True, attn_metadata_list, input_tensors, intermediate_tensors
|
||||
|
||||
def merge_micro_batches(
|
||||
self, input_tensors: Union[List[torch.Tensor],
|
||||
List[List[torch.Tensor]]]
|
||||
) -> List[torch.Tensor]:
|
||||
if input_tensors is None or isinstance(input_tensors[0], torch.Tensor):
|
||||
return input_tensors
|
||||
batch: List[Optional[torch.Tensor]] = []
|
||||
for tensors in input_tensors:
|
||||
if tensors is None or tensors[0] is None:
|
||||
batch.append(None)
|
||||
else:
|
||||
batch.append(torch.cat(tensors, dim=0))
|
||||
return batch
|
||||
|
||||
|
||||
def make_multistream_metadata_ds(
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
causal_lm: bool = True,
|
||||
multistream_config: Optional[MultiStreamConfig] = None,
|
||||
):
|
||||
if multistream_config is None:
|
||||
return None
|
||||
event_keylist = [
|
||||
MSEventKey.ATTN_COM_FINISH,
|
||||
MSEventKey.ATTN_AR_FINISH,
|
||||
MSEventKey.FFN_COM_FINISH,
|
||||
MSEventKey.FFN_AR_FINISH,
|
||||
MSEventKey.MOE_BEFORE_COMM,
|
||||
MSEventKey.MOE_AFTER_COMM,
|
||||
MSEventKey.MOE_SE_COMM_FINISH,
|
||||
MSEventKey.MOE_SE_COMP_FINISH,
|
||||
MSEventKey.MOE_GATE_FINISH,
|
||||
]
|
||||
return MultiStreamMetadata(
|
||||
calculate_stream=torch.npu.current_stream(),
|
||||
communicate_stream=torch.npu.Stream(),
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
multistream_config=multistream_config,
|
||||
event_keys=event_keylist,
|
||||
causal_lm=causal_lm,
|
||||
)
|
||||
@@ -1,247 +0,0 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
|
||||
from .base import MSAttentionMetadataSplitConfig
|
||||
|
||||
|
||||
def compute_split_seq_index(
|
||||
query_lens: Optional[list[int]],
|
||||
attn_state: AscendAttentionState,
|
||||
num_tokens: int,
|
||||
imbalance_ratio: float = 0.1,
|
||||
) -> list[int]:
|
||||
if attn_state != AscendAttentionState.DecodeOnly:
|
||||
assert query_lens is not None
|
||||
total_tokens = sum(query_lens)
|
||||
# the first index in last split
|
||||
tokens, split_index = 0, 0
|
||||
for value in query_lens:
|
||||
tokens += value
|
||||
split_index += 1
|
||||
if tokens >= total_tokens // 2:
|
||||
# check the current split index
|
||||
if abs(tokens -
|
||||
total_tokens // 2) < total_tokens * imbalance_ratio:
|
||||
return [tokens, split_index]
|
||||
# check the previous split index
|
||||
elif abs(tokens - total_tokens // 2 -
|
||||
value) < total_tokens * imbalance_ratio:
|
||||
return [tokens - value, split_index - 1]
|
||||
# fail to split if it is imbalanced
|
||||
# TODO: split tokens in seq
|
||||
else:
|
||||
return [0, 0]
|
||||
else:
|
||||
tokens = num_tokens // 2
|
||||
return [tokens, tokens]
|
||||
return [0, 0]
|
||||
|
||||
|
||||
def split_attn_tensor_type(
|
||||
input_tensor: torch.Tensor,
|
||||
index: int,
|
||||
) -> List[torch.Tensor]:
|
||||
return [input_tensor[:index], input_tensor[index:]]
|
||||
|
||||
|
||||
def split_attn_int_type(
|
||||
var: int,
|
||||
index: int,
|
||||
) -> List[torch.Tensor]:
|
||||
return [min(var, index), max(var - index, 0)]
|
||||
|
||||
|
||||
def model_input_split_v1_mla_attn(
|
||||
attn_metadata,
|
||||
_metadata_cls,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> List[Any]:
|
||||
assert 0 < ms_split_config.num_micro_batches < 3
|
||||
if attn_metadata is None:
|
||||
return [attn_metadata]
|
||||
[token_index,
|
||||
seq_index] = compute_split_seq_index(attn_metadata.query_lens,
|
||||
attn_metadata.attn_state,
|
||||
attn_metadata.num_decode_tokens)
|
||||
if token_index == 0 or seq_index == 0 or seq_index == len(
|
||||
attn_metadata.query_lens):
|
||||
return [attn_metadata]
|
||||
|
||||
query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
|
||||
dtype=int)
|
||||
np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
|
||||
if attn_metadata.num_prefills > 0:
|
||||
prefill_query_start_loc = np.zeros(
|
||||
shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
|
||||
np.cumsum(attn_metadata.prefill.query_lens,
|
||||
out=prefill_query_start_loc[1:])
|
||||
|
||||
# split attn metadata
|
||||
[slot_mapping_pre,
|
||||
slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
|
||||
token_index)
|
||||
[num_decodes_pre,
|
||||
num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
|
||||
seq_index)
|
||||
[num_decode_tokens_pre, num_decode_tokens_post
|
||||
] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
|
||||
[num_prefills_pre, num_prefills_post
|
||||
] = split_attn_int_type(attn_metadata.num_prefills,
|
||||
max(0, seq_index - attn_metadata.num_decodes))
|
||||
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
|
||||
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
||||
|
||||
query_start_loc_pre = query_start_loc_post = None
|
||||
if attn_metadata.query_start_loc is not None:
|
||||
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
|
||||
query_start_loc_post = deepcopy(
|
||||
attn_metadata.query_start_loc[seq_index:]
|
||||
) - attn_metadata.query_start_loc[seq_index]
|
||||
[block_table_pre,
|
||||
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
|
||||
seq_index)
|
||||
assert attn_metadata.attn_mask is not None
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
# the attn_mla kernel in torch npu only accept 128*128 attn mask
|
||||
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
|
||||
attn_state_pre = attn_state_post = attn_metadata.attn_state
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
# should be none in decode only state
|
||||
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
|
||||
attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
|
||||
else:
|
||||
# chunked prefill
|
||||
if num_prefills_pre > 0:
|
||||
attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
|
||||
seq_lens_pre)].contiguous()
|
||||
attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_post = attn_metadata.attn_mask[
|
||||
token_index:, :max(seq_lens_post)].contiguous()
|
||||
else:
|
||||
attn_state_pre = AscendAttentionState.DecodeOnly
|
||||
attn_mask_pre = None
|
||||
attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_post = attn_metadata.attn_mask[
|
||||
token_index:, :max(seq_lens_post)].contiguous()
|
||||
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
|
||||
AscendMLAPrefillMetadata)
|
||||
if num_prefills_pre > 0:
|
||||
# split metadata.prefill
|
||||
[input_positions_pre, input_positions_post] = split_attn_tensor_type(
|
||||
attn_metadata.prefill.input_positions,
|
||||
token_index - attn_metadata.num_decode_tokens)
|
||||
[block_tables_pre, block_tables_post
|
||||
] = split_attn_tensor_type(attn_metadata.prefill.block_table,
|
||||
seq_index - attn_metadata.num_decodes)
|
||||
[prefill_query_lens_pre, prefill_query_lens_post
|
||||
] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
|
||||
seq_index - attn_metadata.num_decodes)
|
||||
prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[:
|
||||
seq_index
|
||||
+
|
||||
1 -
|
||||
attn_metadata
|
||||
.
|
||||
num_decodes]
|
||||
prefill_query_start_loc_post = deepcopy(
|
||||
attn_metadata.prefill.query_start_loc[seq_index -
|
||||
attn_metadata.num_decodes:]
|
||||
) - attn_metadata.prefill.query_start_loc[seq_index -
|
||||
attn_metadata.num_decodes]
|
||||
context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
|
||||
context_len_post = seq_lens_post
|
||||
prefill_max_query_len_pre = max(prefill_query_lens_pre)
|
||||
prefill_max_query_len_post = max(prefill_query_lens_post)
|
||||
prefill_pre = AscendMLAPrefillMetadata(
|
||||
attn_mask=attn_mask_pre,
|
||||
query_lens=prefill_query_lens_pre,
|
||||
seq_lens=seq_lens_pre,
|
||||
query_start_loc=prefill_query_start_loc_pre,
|
||||
input_positions=input_positions_pre,
|
||||
context_lens=context_len_pre,
|
||||
block_table=block_tables_pre,
|
||||
max_query_len=prefill_max_query_len_pre,
|
||||
max_seq_lens=context_len_pre.max().item(),
|
||||
)
|
||||
prefill_post = AscendMLAPrefillMetadata(
|
||||
attn_mask=attn_mask_post,
|
||||
query_lens=prefill_query_lens_post,
|
||||
seq_lens=seq_lens_post,
|
||||
query_start_loc=prefill_query_start_loc_post,
|
||||
input_positions=input_positions_post,
|
||||
context_lens=context_len_post,
|
||||
block_table=block_tables_post,
|
||||
max_query_len=prefill_max_query_len_post,
|
||||
max_seq_lens=context_len_post.max().item(),
|
||||
)
|
||||
decode_pre = attn_metadata.decode
|
||||
decode_post = None
|
||||
else:
|
||||
# prefill is None, split metadata.decode
|
||||
[input_positions_pre, input_positions_post
|
||||
] = split_attn_tensor_type(attn_metadata.decode.input_positions,
|
||||
token_index)
|
||||
[block_tables_pre, block_tables_post
|
||||
] = split_attn_tensor_type(attn_metadata.decode.block_table,
|
||||
seq_index)
|
||||
[decode_seq_lens_pre,
|
||||
decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
||||
decode_pre = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions_pre,
|
||||
block_table=block_tables_pre,
|
||||
seq_lens=decode_seq_lens_pre,
|
||||
max_seq_lens=max(decode_seq_lens_pre),
|
||||
seq_lens_list=decode_seq_lens_pre.tolist(),
|
||||
)
|
||||
decode_post = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions_post,
|
||||
block_table=block_tables_post,
|
||||
seq_lens=decode_seq_lens_post,
|
||||
max_seq_lens=max(decode_seq_lens_post),
|
||||
seq_lens_list=decode_seq_lens_post.tolist(),
|
||||
)
|
||||
prefill_pre = None
|
||||
prefill_post = attn_metadata.prefill
|
||||
# construct metadata
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata
|
||||
attention_metadata_pre = _metadata_cls(
|
||||
num_actual_tokens=token_index,
|
||||
num_input_tokens=token_index,
|
||||
head_dim=attn_metadata.head_dim,
|
||||
slot_mapping=slot_mapping_pre,
|
||||
seq_lens=seq_lens_pre,
|
||||
query_start_loc=query_start_loc_pre,
|
||||
block_tables=block_table_pre,
|
||||
num_decodes=num_decodes_pre,
|
||||
num_prefills=num_prefills_pre,
|
||||
num_decode_tokens=num_decode_tokens_pre,
|
||||
attn_state=attn_state_pre,
|
||||
attn_mask=attn_mask_pre,
|
||||
prefill=prefill_pre,
|
||||
decode=decode_pre,
|
||||
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
attention_metadata_post = _metadata_cls(
|
||||
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
|
||||
num_input_tokens=attn_metadata.num_input_tokens - token_index,
|
||||
head_dim=attn_metadata.head_dim,
|
||||
slot_mapping=slot_mapping_post,
|
||||
seq_lens=seq_lens_post,
|
||||
query_start_loc=query_start_loc_post,
|
||||
block_tables=block_table_post,
|
||||
num_decodes=num_decodes_post,
|
||||
num_prefills=num_prefills_post,
|
||||
num_decode_tokens=num_decode_tokens_post,
|
||||
attn_mask=attn_mask_post,
|
||||
attn_state=attn_state_post,
|
||||
prefill=prefill_post,
|
||||
decode=decode_post,
|
||||
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
return [attention_metadata_pre, attention_metadata_post]
|
||||
@@ -122,10 +122,11 @@ class MtpProposer(Proposer):
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor=None) -> None:
|
||||
if not self.torchair_graph_enabled:
|
||||
# TODO: adapt enable_dbo later
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self.runner._sync_metadata_across_dp(num_tokens,
|
||||
with_prefill, False)
|
||||
(
|
||||
num_tokens,
|
||||
num_tokens_across_dp,
|
||||
with_prefill,
|
||||
) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill)
|
||||
|
||||
moe_comm_type = self.runner._select_moe_comm_method(
|
||||
num_tokens, with_prefill)
|
||||
@@ -429,10 +430,9 @@ class MtpProposer(Proposer):
|
||||
|
||||
if not self.torchair_graph_enabled:
|
||||
# torch mode need to update num_tokens_across_dp
|
||||
# TODO: adapt enable_dbo later
|
||||
(num_input_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self.runner._sync_metadata_across_dp(
|
||||
num_input_tokens, self.runner.with_prefill, False)
|
||||
(num_input_tokens, num_tokens_across_dp,
|
||||
with_prefill) = self.runner._sync_metadata_across_dp(
|
||||
num_input_tokens, self.runner.with_prefill)
|
||||
else:
|
||||
# torchair mode can reuse self.runner.num_tokens_across_dp
|
||||
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
||||
|
||||
@@ -264,8 +264,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
|
||||
attn_state=attn_state)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
|
||||
@@ -20,9 +20,6 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||
npu_stream_switch, npu_wait_tensor)
|
||||
@@ -141,7 +138,6 @@ class AscendMLATorchairMetadata:
|
||||
|
||||
decode: Optional[AscendMLATorchairDecodeMetadata] = None
|
||||
prefill: Optional[AscendMLATorchairPrefillMetadata] = None
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
@@ -152,17 +148,6 @@ class AscendMLATorchairMetadata:
|
||||
# f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
# f"received {self.head_dim}.")
|
||||
|
||||
def split_metadata_for_multistream(
|
||||
self,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> list["AscendMLATorchairMetadata"]:
|
||||
"""Split metadata for multi-stream with AscendMLATorchairMetadata"""
|
||||
return model_input_split_v1_mla_attn(
|
||||
ms_split_config=ms_split_config,
|
||||
attn_metadata=self,
|
||||
_metadata_cls=AscendMLATorchairMetadata,
|
||||
)
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendMLATorchairMetadata)
|
||||
|
||||
@@ -576,7 +561,6 @@ class AscendMLATorchairMetadataBuilder:
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
|
||||
def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
|
||||
@@ -1072,15 +1056,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
return self._v_up_proj_and_o_proj(attn_output,
|
||||
enable_multistream_mla)
|
||||
else:
|
||||
current_ms_metadata.before_comm_event.record()
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
current_ms_metadata.before_comm_event.wait()
|
||||
return self._v_up_proj_and_o_proj(attn_output)
|
||||
|
||||
return self._v_up_proj_and_o_proj(attn_output, enable_multistream_mla)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -1248,13 +1225,6 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
prefill_k_c_normed,
|
||||
prefill_k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is not None:
|
||||
current_ms_metadata.before_comm_event.record()
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
current_ms_metadata.before_comm_event.wait()
|
||||
o_proj_input[num_decode_tokens:] = output_prefill
|
||||
else:
|
||||
o_proj_input[num_decode_tokens:] = output_prefill
|
||||
|
||||
if has_decode:
|
||||
@@ -1269,16 +1239,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
decode_k_nope,
|
||||
decode_k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is not None:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
o_proj_input[:num_decode_tokens] = output_decode
|
||||
else:
|
||||
o_proj_input[:num_decode_tokens] = output_decode
|
||||
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
|
||||
if current_ms_metadata is None:
|
||||
|
||||
maybe_npu_prefetch(self.o_proj.weight,
|
||||
o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
@@ -1288,16 +1252,6 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
o_proj_input,
|
||||
is_prefill=True,
|
||||
is_force_scatter=self.enable_shared_expert_dp)[0]
|
||||
else:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
maybe_npu_prefetch(self.o_proj.weight,
|
||||
o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=enable_multistream_mla)
|
||||
output[...] = self.o_proj(
|
||||
o_proj_input,
|
||||
is_prefill=True,
|
||||
is_force_scatter=self.enable_shared_expert_dp)[0]
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
|
||||
del o_proj_input
|
||||
return output_padded
|
||||
|
||||
@@ -110,30 +110,28 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
|
||||
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
||||
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||
self, num_tokens: int,
|
||||
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
|
||||
"""Override from NPUModelRunner to pad num_tokens"""
|
||||
if self.enable_shared_expert_dp:
|
||||
# Padding is not required for shared_expert_dp cases in eager mode.
|
||||
return num_tokens, None, with_prefill, enable_dbo
|
||||
return num_tokens, None, with_prefill
|
||||
if self.dp_size == 1:
|
||||
if not with_prefill:
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
num_tokens)
|
||||
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
|
||||
return num_tokens, None, with_prefill, enable_dbo
|
||||
return maybe_padded_num_tokens, None, with_prefill
|
||||
return num_tokens, None, with_prefill
|
||||
|
||||
num_tokens_across_dp = torch.zeros(self.dp_size + 2,
|
||||
num_tokens_across_dp = torch.zeros(self.dp_size + 1,
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
num_tokens_across_dp[self.dp_rank] = num_tokens
|
||||
num_tokens_across_dp[-2] = int(with_prefill)
|
||||
num_tokens_across_dp[-1] = int(not enable_dbo)
|
||||
num_tokens_across_dp[-1] = int(with_prefill)
|
||||
dist.all_reduce(num_tokens_across_dp,
|
||||
group=get_dp_group().device_group)
|
||||
with_prefill = bool(num_tokens_across_dp[-2])
|
||||
enable_dbo = not bool(num_tokens_across_dp[-1])
|
||||
num_tokens_across_dp = num_tokens_across_dp[:-2]
|
||||
with_prefill = bool(num_tokens_across_dp[-1])
|
||||
num_tokens_across_dp = num_tokens_across_dp[:-1]
|
||||
|
||||
if not with_prefill:
|
||||
max_num_token = num_tokens_across_dp.max().item()
|
||||
@@ -146,7 +144,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
else:
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill
|
||||
|
||||
def _build_dummy_attn_metadata(
|
||||
self,
|
||||
|
||||
@@ -21,8 +21,6 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
|
||||
from vllm_ascend.utils import is_enable_nz
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
@@ -141,7 +139,6 @@ class AscendSFATorchairMetadata:
|
||||
|
||||
decode: Optional[AscendSFATorchairDecodeMetadata] = None
|
||||
prefill: Optional[AscendSFATorchairPrefillMetadata] = None
|
||||
enable_dbo_across_dp: bool = False
|
||||
is_prefill: bool = False
|
||||
is_decode: bool = False
|
||||
|
||||
@@ -154,17 +151,6 @@ class AscendSFATorchairMetadata:
|
||||
# f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
# f"received {self.head_dim}.")
|
||||
|
||||
def split_metadata_for_multistream(
|
||||
self,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> list["AscendSFATorchairMetadata"]:
|
||||
"""Split metadata for multi-stream with AscendSFATorchairMetadata"""
|
||||
return model_input_split_v1_mla_attn(
|
||||
ms_split_config=ms_split_config,
|
||||
attn_metadata=self,
|
||||
_metadata_cls=AscendSFATorchairMetadata,
|
||||
)
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendSFATorchairMetadata)
|
||||
|
||||
@@ -616,7 +602,6 @@ class AscendSFATorchairMetadataBuilder:
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
is_prefill=is_prefill,
|
||||
is_decode=is_decode)
|
||||
|
||||
|
||||
@@ -758,13 +758,13 @@ def get_default_buffer_config() -> dict:
|
||||
def calculate_dp_buffer_size() -> int:
|
||||
"""
|
||||
formula of dp buffer size:
|
||||
dp_size + 2 (flags: with_prefill and enable_dbo)
|
||||
dp_size + 1 (flags: with_prefill)
|
||||
"""
|
||||
from vllm.config import get_current_vllm_config
|
||||
vllm_config = get_current_vllm_config()
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
int32_size = torch.iinfo(torch.int32).bits // 8
|
||||
dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
|
||||
dp_buffer_size = math.ceil((dp_size + 1) * int32_size / (1024 * 1024))
|
||||
return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
|
||||
|
||||
|
||||
|
||||
@@ -121,7 +121,6 @@ from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
|
||||
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
||||
from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
||||
from vllm_ascend.eplb.utils import model_register
|
||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
||||
@@ -859,8 +858,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
||||
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||
self, num_tokens: int,
|
||||
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
|
||||
# TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in
|
||||
# our case, we still need to sync the other two flags as well. So we need to
|
||||
# include them in the all_reduce operation, and more over, we CANNOT skip it
|
||||
@@ -868,17 +867,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here
|
||||
# immediately once the other two flags are no longer needed.
|
||||
if self.dp_size == 1:
|
||||
return num_tokens, None, with_prefill, enable_dbo
|
||||
return num_tokens, None, with_prefill
|
||||
|
||||
# Sync num_tokens, with_prefill, enable_dbo across dp ranks
|
||||
# Sync num_tokens, with_prefill across dp ranks
|
||||
num_tokens_tensor = torch.tensor([
|
||||
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
|
||||
flags_tensor = torch.tensor(
|
||||
[int(with_prefill), int(not enable_dbo)],
|
||||
flags_tensor = torch.tensor([int(with_prefill)],
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
|
||||
@@ -887,12 +885,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
|
||||
|
||||
# Unpack the results
|
||||
num_tokens_across_dp = packed_tensor[:-2]
|
||||
synced_flags = packed_tensor[-2:]
|
||||
num_tokens_across_dp = packed_tensor[:-1]
|
||||
synced_flags = packed_tensor[-1:]
|
||||
|
||||
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
|
||||
global_with_prefill = bool(synced_flags[0])
|
||||
global_enable_dbo = not bool(synced_flags[1])
|
||||
|
||||
# Create a tensor for num_tokens_after_padding
|
||||
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
|
||||
@@ -900,28 +897,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
|
||||
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
|
||||
|
||||
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
|
||||
attn_state: AscendAttentionState,
|
||||
num_tokens: int) -> bool:
|
||||
# do the checks for dp + dbo
|
||||
if attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
AscendAttentionState.SpecDecoding
|
||||
]:
|
||||
return False
|
||||
# considering the case that one dp rank may enable dbo while others may not
|
||||
if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||
return False
|
||||
# TODO: remove it if token-level microbatch is enabled
|
||||
[token_index,
|
||||
seq_index] = compute_split_seq_index(query_lens, attn_state,
|
||||
num_tokens)
|
||||
if token_index == 0 or seq_index == 0 or seq_index == len(
|
||||
query_lens) or num_tokens < 256:
|
||||
return False
|
||||
return True
|
||||
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
# get raw model out of the aclgraph wrapper.
|
||||
@@ -1430,16 +1406,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
]
|
||||
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
||||
attn_state,
|
||||
total_num_scheduled_tokens)
|
||||
|
||||
# Get info across DP ranks.
|
||||
# NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
|
||||
# Otherwise, it's just max_tokens_across_dp_cpu
|
||||
(maybe_padded_num_tokens, num_tokens_across_dp, with_prefill,
|
||||
enable_dbo) = self._sync_metadata_across_dp(num_input_tokens,
|
||||
with_prefill, enable_dbo)
|
||||
(maybe_padded_num_tokens, num_tokens_across_dp,
|
||||
with_prefill) = self._sync_metadata_across_dp(num_input_tokens,
|
||||
with_prefill)
|
||||
|
||||
# TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
|
||||
# We should consider removing maybe_padded_num_tokens later
|
||||
@@ -1707,7 +1680,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
attn_mask=self.attn_mask,
|
||||
spec_attn_mask=self.spec_attn_mask,
|
||||
attn_state=self.attn_state,
|
||||
enable_dbo_across_dp=enable_dbo,
|
||||
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
graph_pad_size=self.graph_pad_size,
|
||||
@@ -2603,8 +2575,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens = math.ceil(num_tokens / tp_size) * tp_size
|
||||
|
||||
# Padding for DP
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
|
||||
(num_tokens, num_tokens_across_dp,
|
||||
with_prefill) = self._sync_metadata_across_dp(num_tokens,
|
||||
with_prefill)
|
||||
|
||||
moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user