support FULL graph mode for GQA (#3970)
### What this PR does / why we need it?
The current library only supports the FullDecodeOnly graph mode, which
enables full graph execution during the decode. This PR extends support
to allow full graph execution in both the prefill and decode, referred
to as FULL graph mode.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
1
.github/workflows/_e2e_test.yaml
vendored
1
.github/workflows/_e2e_test.yaml
vendored
@@ -180,6 +180,7 @@ jobs:
|
||||
if: ${{ inputs.type == 'full' }}
|
||||
run: |
|
||||
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
|
||||
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
|
||||
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
||||
pytest -sv tests/e2e/multicard/test_expert_parallel.py
|
||||
pytest -sv tests/e2e/multicard/test_external_launcher.py
|
||||
|
||||
@@ -29,7 +29,7 @@ from tests.e2e.conftest import VllmRunner
|
||||
from tests.e2e.model_utils import check_outputs_equal
|
||||
|
||||
|
||||
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
|
||||
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
|
||||
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
|
||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||
prompts = [
|
||||
@@ -42,15 +42,64 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
|
||||
max_model_len=1024,
|
||||
tensor_parallel_size=2,
|
||||
enforce_eager=False,
|
||||
compilation_config={"cudagraph_mode":
|
||||
"FULL_DECODE_ONLY"}) as runner:
|
||||
compilation_config={
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
"cudagraph_capture_sizes": [4, 8, 24, 48, 60]
|
||||
}) as runner:
|
||||
vllm_fullgraph_outputs = runner.model.generate(prompts,
|
||||
sampling_params)
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=1024,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=2,
|
||||
enforce_eager=False,
|
||||
) as runner:
|
||||
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
|
||||
|
||||
vllm_fullgraph_outputs_list = []
|
||||
for output in vllm_fullgraph_outputs:
|
||||
vllm_fullgraph_outputs_list.append(
|
||||
(output.outputs[0].index, output.outputs[0].text))
|
||||
|
||||
vllm_eager_outputs_list = []
|
||||
for output in vllm_eager_outputs:
|
||||
vllm_eager_outputs_list.append(
|
||||
(output.outputs[0].index, output.outputs[0].text))
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=vllm_eager_outputs_list,
|
||||
outputs_1_lst=vllm_fullgraph_outputs_list,
|
||||
name_0="vllm_eager_outputs",
|
||||
name_1="vllm_fullgraph_outputs",
|
||||
)
|
||||
|
||||
|
||||
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL():
|
||||
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
|
||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||
prompts = [
|
||||
"Hello, my name is", "The president of the United States is",
|
||||
"The capital of France is", "The future of AI is"
|
||||
]
|
||||
model = "Qwen/Qwen3-30B-A3B"
|
||||
sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
|
||||
with VllmRunner(model,
|
||||
max_model_len=1024,
|
||||
tensor_parallel_size=2,
|
||||
enforce_eager=False,
|
||||
compilation_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
"cudagraph_capture_sizes": [4, 8, 24, 48, 60]
|
||||
}) as runner:
|
||||
vllm_fullgraph_outputs = runner.model.generate(prompts,
|
||||
sampling_params)
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=1024,
|
||||
tensor_parallel_size=2,
|
||||
enforce_eager=False,
|
||||
) as runner:
|
||||
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ def mtp_correctness(sampling_config: SamplingParams,
|
||||
|
||||
graph_mode_str = "PIECEWISE"
|
||||
if graph_mode == CUDAGraphMode.FULL:
|
||||
graph_mode_str = "FULL"
|
||||
graph_mode_str = "FULL_DECODE_ONLY"
|
||||
|
||||
with VllmRunner(
|
||||
model_name,
|
||||
@@ -63,7 +63,9 @@ def mtp_correctness(sampling_config: SamplingParams,
|
||||
enforce_eager=enforce_eager,
|
||||
max_model_len=2000,
|
||||
compilation_config=CompilationConfig(
|
||||
cudagraph_mode=graph_mode_str),
|
||||
cudagraph_mode=graph_mode_str,
|
||||
cudagraph_capture_sizes=[12],
|
||||
),
|
||||
additional_config={"ascend_scheduler_config": {
|
||||
"enabled": False
|
||||
}}) as spec_llm:
|
||||
|
||||
@@ -286,10 +286,12 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_flash_attention')
|
||||
def test_forward_prefill_no_cache(self, mock_flash_attention,
|
||||
mock_reshape_cache):
|
||||
mock_reshape_cache,
|
||||
mock_get_forward_context):
|
||||
"""Test forward pass in PrefillNoCache state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
@@ -297,6 +299,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.PrefillNoCache
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
@@ -316,7 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
def test_forward_prefill_cache_hit(self,
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
def test_forward_prefill_cache_hit(self, mock_get_forward_context,
|
||||
mock_npu_fused_infer_attention_score,
|
||||
mock_npu_reshape_and_cache):
|
||||
"""Test forward pass in PrefillCacheHit state"""
|
||||
@@ -326,8 +331,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
mock_npu_fused_infer_attention_score.return_value = (output, 1)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
@@ -340,18 +343,23 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
metadata.num_prefills = 10
|
||||
layer = self.layer_no_quant
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
mock_npu_fused_infer_attention_score.return_value = (output,
|
||||
torch.ones(
|
||||
10, 8, 64))
|
||||
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_npu_fused_infer_attention_score.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
def test_forward_decode_only(self, mock_paged_attention,
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
def test_forward_decode_only(self, mock_get_forward_context,
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_get_forward_context):
|
||||
mock_paged_attention):
|
||||
"""Test forward pass in DecodeOnly state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
@@ -378,115 +386,11 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('vllm_ascend.attention.attention_v1.get_graph_params')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
@patch('torch.npu.graph_task_group_end')
|
||||
@patch('torch.npu.graph_task_group_begin')
|
||||
@patch('torch.npu.ExternalEvent')
|
||||
@patch('torch_npu.npu.current_stream')
|
||||
@patch('vllm_ascend.attention.attention_v1.weak_ref_tensors')
|
||||
def test_paged_attention_with_existing_workspace(
|
||||
self,
|
||||
mock_get_forward_context,
|
||||
mock_get_graph_params,
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_paged_attention,
|
||||
mock_graph_begin,
|
||||
mock_graph_end,
|
||||
mock_external_event_class,
|
||||
mock_current_stream,
|
||||
mock_weak_ref_tensors,
|
||||
):
|
||||
graph_params = MagicMock()
|
||||
attn_metadata = MagicMock()
|
||||
num_tokens = 10
|
||||
|
||||
graph_params.workspaces = {num_tokens: 10}
|
||||
graph_params.events = {num_tokens: []}
|
||||
graph_params.attn_params = {num_tokens: []}
|
||||
graph_params.handles = {num_tokens: []}
|
||||
|
||||
query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size]
|
||||
key_cache = MagicMock()
|
||||
value_cache = MagicMock()
|
||||
num_kv_heads = 4
|
||||
num_heads = 8
|
||||
scale = 0.1
|
||||
output = torch.randn(2, 5, 8)
|
||||
|
||||
self_obj = MagicMock()
|
||||
self_obj.key_cache = key_cache
|
||||
self_obj.value_cache = value_cache
|
||||
self_obj.num_kv_heads = num_kv_heads
|
||||
self_obj.num_heads = num_heads
|
||||
self_obj.scale = scale
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_current_stream.return_value = mock_stream
|
||||
mock_event_instance = MagicMock()
|
||||
mock_external_event_class.return_value = mock_event_instance
|
||||
|
||||
mock_handle = MagicMock()
|
||||
mock_graph_end.return_value = mock_handle
|
||||
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
self.assertEqual(workspace, 10)
|
||||
|
||||
weak_ref_tensors = MagicMock(side_effect=lambda x: x)
|
||||
|
||||
# 2. Handle graph capturing mode
|
||||
stream = mock_current_stream()
|
||||
event = mock_external_event_class()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
weak_ref_tensors(query),
|
||||
weak_ref_tensors(self_obj.key_cache),
|
||||
weak_ref_tensors(self_obj.value_cache),
|
||||
self_obj.num_kv_heads,
|
||||
self_obj.num_heads,
|
||||
self_obj.scale,
|
||||
weak_ref_tensors(attn_metadata.block_tables),
|
||||
attn_metadata.seq_lens,
|
||||
output,
|
||||
))
|
||||
|
||||
mock_event_instance.wait.assert_called_once_with(mock_stream)
|
||||
mock_event_instance.reset.assert_called_once_with(mock_stream)
|
||||
self.assertEqual(len(graph_params.events[num_tokens]), 1)
|
||||
self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
|
||||
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
metadata.num_decodes = 0
|
||||
metadata.num_prefills = 10
|
||||
layer = self.layer_no_quant
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=True)
|
||||
mock_get_graph_params.return_value = graph_params
|
||||
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
|
||||
mock_npu_reshape_and_cache):
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
|
||||
mock_fused_infer_attention_score,
|
||||
mock_get_forward_context):
|
||||
"""Test forward pass in DecodeOnly state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
@@ -494,6 +398,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty(10, 8, 64)
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.seq_lens = torch.tensor([10] * 10)
|
||||
@@ -512,12 +418,12 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
assert output.shape == (10, 8, 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
def test_forward_decode_only_swa_seq_len_mismatch(
|
||||
self, mock_fused_infer_attention_score, mock_paged_attention,
|
||||
mock_npu_reshape_and_cache, mock_get_forward_context):
|
||||
self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
|
||||
mock_paged_attention, mock_get_forward_context):
|
||||
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
@@ -535,11 +441,11 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
metadata.num_decodes = 10
|
||||
metadata.num_prefills = 0
|
||||
|
||||
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
|
||||
64), 1)
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, 64),
|
||||
torch.ones(10, 8, 64))
|
||||
|
||||
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
@@ -548,11 +454,13 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
|
||||
def test_forward_head_size_192(self, mock_vanilla_prefill,
|
||||
mock_npu_reshape_and_cache, mock_is_310p):
|
||||
mock_npu_reshape_and_cache, mock_is_310p,
|
||||
mock_get_forward_context):
|
||||
"""Test forward pass when head_size is 192"""
|
||||
|
||||
self.impl.head_size = 192
|
||||
@@ -562,6 +470,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 192)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
@@ -580,11 +490,12 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_vanilla_prefill.assert_called_once()
|
||||
assert output.shape == (10, 8 * 192)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
def test_forward_normal_v1_situation(self,
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
|
||||
mock_npu_fused_infer_attention_score,
|
||||
mock_npu_reshape_and_cache):
|
||||
mock_get_forward_context):
|
||||
"""Test forward pass in normal V1 situation"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
@@ -592,8 +503,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
mock_npu_fused_infer_attention_score.return_value = (output, 1)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
@@ -604,6 +513,10 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
metadata.num_decodes = 0
|
||||
metadata.num_prefills = 10
|
||||
layer = self.layer_no_quant
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
mock_npu_fused_infer_attention_score.return_value = (output,
|
||||
torch.ones(
|
||||
10, 8, 64))
|
||||
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
@@ -615,7 +528,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
|
||||
def test_forward_310p_device(self, mock_is_310p,
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
def test_forward_310p_device(self, mock_get_forward_context, mock_is_310p,
|
||||
mock_npu_fused_infer_attention_score,
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_npu_format_cast):
|
||||
@@ -626,8 +540,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
mock_npu_fused_infer_attention_score.return_value = (output, 1)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
@@ -641,6 +553,11 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
|
||||
mock_npu_format_cast.return_value = metadata.attn_mask
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
mock_npu_fused_infer_attention_score.return_value = (output,
|
||||
torch.ones(
|
||||
10, 8, 64))
|
||||
|
||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
|
||||
@@ -195,6 +195,7 @@ class AscendMetadataForDecode:
|
||||
class AscendMetadata:
|
||||
# **************************** Basic Properties ************************** #
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
fia_attn_mask: Optional[torch.Tensor] = None
|
||||
# Current state of this attention run.
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
@@ -215,6 +216,7 @@ class AscendMetadata:
|
||||
seq_lens: torch.Tensor = None
|
||||
seq_lens_list: List[int] = None # type: ignore
|
||||
actual_seq_lengths_q: List[int] = None # type: ignore
|
||||
query_start_loc_list: List[int] = None # type: ignore
|
||||
|
||||
query_start_loc: torch.Tensor = None
|
||||
query_lens: torch.Tensor = None
|
||||
@@ -241,7 +243,8 @@ class AscendMetadata:
|
||||
class AscendAttentionMetadataBuilder:
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
AttentionCGSupport.ALWAYS
|
||||
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
@@ -321,6 +324,7 @@ class AscendAttentionMetadataBuilder:
|
||||
num_actual_tokens_pcp_padded]
|
||||
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
fia_attn_mask = common_attn_metadata.fia_attn_mask
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
@@ -471,6 +475,7 @@ class AscendAttentionMetadataBuilder:
|
||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_list=query_start_loc_cpu[1:].tolist(),
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
@@ -478,6 +483,7 @@ class AscendAttentionMetadataBuilder:
|
||||
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
fia_attn_mask=fia_attn_mask,
|
||||
attn_state=attn_state,
|
||||
num_prefills=num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
@@ -565,6 +571,113 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.dcp_group = get_dcp_group(
|
||||
).device_group if self.dcp_size > 1 else None
|
||||
|
||||
def full_graph_attention(self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
num_tokens=0):
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
block_size = 128
|
||||
block_table = None
|
||||
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
batch_size = attn_metadata.query_lens.shape[0]
|
||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
block_table = attn_metadata.block_tables
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
block_table = attn_metadata.block_tables
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
|
||||
num_tokens = attn_metadata.query_start_loc_list[-1]
|
||||
query = query[:num_tokens]
|
||||
graph_params = get_graph_params()
|
||||
query_start_loc = attn_metadata.query_start_loc_list
|
||||
# Prepare tensors for attention output
|
||||
# TODO: Refactor this to step-level instead of layer-level
|
||||
|
||||
# Get workspace from cache or calculate it if not present.
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
atten_mask=attn_metadata.fia_attn_mask,
|
||||
block_table=block_table,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=query_start_loc,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
sparse_mode=3,
|
||||
scale=self.scale,
|
||||
)
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
# Handle graph capturing mode
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(weak_ref_tensors(query), weak_ref_tensors(key),
|
||||
weak_ref_tensors(value), weak_ref_tensors(block_table),
|
||||
weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
|
||||
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
|
||||
self.num_heads, self.scale, weak_ref_tensors(output),
|
||||
weak_ref_tensors(softmax_lse)))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
atten_mask=attn_metadata.fia_attn_mask,
|
||||
block_table=block_table,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=query_start_loc,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
sparse_mode=3,
|
||||
workspace=workspace,
|
||||
out=[output, softmax_lse],
|
||||
)
|
||||
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
return output, num_tokens
|
||||
|
||||
def _forward_prefill_no_cache(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -692,70 +805,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||
else:
|
||||
graph_params = get_graph_params()
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
num_tokens = query.shape[0]
|
||||
if forward_context.capturing:
|
||||
# Get workspace from cache or calculate it if not present.
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
update_graph_params_workspaces(num_tokens,
|
||||
weak_ref_tensors(workspace))
|
||||
|
||||
# Handle graph capturing mode
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
weak_ref_tensors(query),
|
||||
weak_ref_tensors(self.key_cache),
|
||||
weak_ref_tensors(self.value_cache),
|
||||
self.num_kv_heads,
|
||||
self.num_heads,
|
||||
self.scale,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens,
|
||||
weak_ref_tensors(output),
|
||||
))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def _forward_v1_style(
|
||||
@@ -819,7 +878,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
scale=self.scale,
|
||||
sparse_mode=3,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
|
||||
@@ -1481,47 +1539,51 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_decode_tokens:attn_metadata.
|
||||
num_actual_tokens_pcp_padded])
|
||||
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
intermediate_output = self._forward_pcp_dcp(
|
||||
query, key, value, kv_cache, attn_metadata, output)
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
|
||||
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
||||
intermediate_output = torch_npu.npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
head_num=self.num_heads,
|
||||
input_layout="TND",
|
||||
scale=self.scale,
|
||||
sparse_mode=4,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
pre_tockens=attn_metadata.max_query_len,
|
||||
next_tockens=attn_metadata.max_query_len,
|
||||
actual_seq_qlen=cum_seq_len,
|
||||
actual_seq_kvlen=cum_seq_len,
|
||||
)[0]
|
||||
# V0-Style scheduler situation.
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
intermediate_output = self._forward_prefill_no_cache(
|
||||
query, key, value, attn_metadata, output, num_tokens)
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
intermediate_output = self._forward_prefill_cache_hit(
|
||||
query, attn_metadata, output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
intermediate_output = self._forward_decode_only(
|
||||
query, attn_metadata, output)
|
||||
# Normal V1 situation.
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if not forward_context.capturing:
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
intermediate_output = self._forward_pcp_dcp(
|
||||
query, key, value, kv_cache, attn_metadata, output)
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
|
||||
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
||||
intermediate_output = torch_npu.npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
head_num=self.num_heads,
|
||||
input_layout="TND",
|
||||
scale=self.scale,
|
||||
sparse_mode=4,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
pre_tockens=attn_metadata.max_query_len,
|
||||
next_tockens=attn_metadata.max_query_len,
|
||||
actual_seq_qlen=cum_seq_len,
|
||||
actual_seq_kvlen=cum_seq_len,
|
||||
)[0]
|
||||
# V0-Style scheduler situation.
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
intermediate_output = self._forward_prefill_no_cache(
|
||||
query, key, value, attn_metadata, output, num_tokens)
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
intermediate_output = self._forward_prefill_cache_hit(
|
||||
query, attn_metadata, output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
intermediate_output = self._forward_decode_only(
|
||||
query, attn_metadata, output)
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
# npu_fused_infer_attention_score does not support cases
|
||||
# where query.shape[0] != attn_metadata.query_start_loc[-1].
|
||||
# Thus we need unpad it here.
|
||||
num_tokens = attn_metadata.query_start_loc[-1]
|
||||
query = query[:num_tokens]
|
||||
intermediate_output = self._forward_v1_style(
|
||||
query, attn_metadata, output)
|
||||
else:
|
||||
# npu_fused_infer_attention_score does not support cases
|
||||
# where query.shape[0] != attn_metadata.query_start_loc[-1].
|
||||
# Thus we need unpad it here.
|
||||
num_tokens = attn_metadata.query_start_loc[-1]
|
||||
query = query[:num_tokens]
|
||||
intermediate_output = self._forward_v1_style(
|
||||
query, attn_metadata, output)
|
||||
|
||||
intermediate_output, num_tokens = self.full_graph_attention(
|
||||
query, key, value, attn_metadata, output)
|
||||
output[:num_tokens] = intermediate_output[:num_tokens]
|
||||
|
||||
return output
|
||||
|
||||
@@ -1278,8 +1278,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
q_nope, k_nope, k_nope, **common_kwargs)
|
||||
update_graph_params_workspaces(num_tokens,
|
||||
weak_ref_tensors(workspace))
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
attn_output = torch.empty_like(q_nope)
|
||||
softmax_lse = torch.empty(num_tokens,
|
||||
@@ -1779,8 +1778,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
|
||||
seq_len, num_heads, self.scale, self.num_kv_heads,
|
||||
**common_kwargs)
|
||||
update_graph_params_workspaces(num_tokens,
|
||||
weak_ref_tensors(workspace))
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
attn_output = torch.empty_like(q_nope)
|
||||
softmax_lse = torch.empty((num_tokens, num_heads, 1),
|
||||
dtype=q_nope.dtype,
|
||||
|
||||
@@ -88,6 +88,8 @@ class AscendCommonAttentionMetadata:
|
||||
|
||||
attn_mask: torch.Tensor = None
|
||||
|
||||
fia_attn_mask: torch.Tensor = None
|
||||
|
||||
spec_attn_mask: torch.Tensor = None
|
||||
|
||||
attn_state: Any = None
|
||||
|
||||
@@ -203,48 +203,31 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
block_table,
|
||||
seq_lens,
|
||||
output,
|
||||
) = param
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||
(query, key_cache, value, block_tables, attn_mask, block_size,
|
||||
seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
|
||||
attn_output, softmax_lse) = param
|
||||
|
||||
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
|
||||
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
|
||||
# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
|
||||
# might encounter a bigger workspace, while currently we use max_model_len to
|
||||
# calculate max workspace in capturing. So additional get_workspace is added
|
||||
# here to avoid such bugs.
|
||||
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
|
||||
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens_list
|
||||
query_start_loc = forward_context.attn_metadata[
|
||||
key].query_start_loc_list
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query=query,
|
||||
key=key_cache,
|
||||
value=value,
|
||||
block_table=block_tables,
|
||||
atten_mask=attn_mask,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=query_start_loc,
|
||||
actual_seq_lengths_kv=seq_lens,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale=scale,
|
||||
sparse_mode=3,
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
out=[attn_output, softmax_lse],
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -446,10 +429,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
)
|
||||
|
||||
|
||||
def update_graph_params_workspaces(num_tokens: int, workspace: Any):
|
||||
def update_graph_params_workspaces(num_tokens: int, workspace: int):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
_graph_params.workspaces[num_tokens] = workspace
|
||||
_graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
|
||||
|
||||
|
||||
def get_graph_params():
|
||||
|
||||
@@ -233,7 +233,8 @@ class NPUPlatform(Platform):
|
||||
"vllm.mla_forward"
|
||||
])
|
||||
update_aclgraph_sizes(vllm_config)
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
|
||||
compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
|
||||
logger.info(
|
||||
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
|
||||
"using only ACL Graph mode")
|
||||
@@ -270,7 +271,8 @@ class NPUPlatform(Platform):
|
||||
compilation_config.use_inductor = False
|
||||
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
|
||||
update_aclgraph_sizes(vllm_config)
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
|
||||
compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
|
||||
logger.info(
|
||||
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
|
||||
"using only ACL Graph mode")
|
||||
|
||||
@@ -21,6 +21,7 @@ import math
|
||||
import types
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
@@ -31,7 +32,6 @@ from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
|
||||
import numpy as np
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
@@ -331,6 +331,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.attn_groups: list[list[AttentionGroup]] = []
|
||||
self.encoder_cache: Dict[str, torch.Tensor] = {}
|
||||
self.attn_mask = None
|
||||
self.fia_attn_mask = None
|
||||
self.attn_state = None
|
||||
self.requests: Dict[str, CachedRequestState] = {}
|
||||
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
||||
@@ -1030,6 +1031,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
return None
|
||||
|
||||
def _make_fia_attention_mask(self) -> torch.Tensor:
|
||||
if self.attn_mask_builder is None:
|
||||
raise ValueError("Attn mask builder is None")
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
|
||||
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
|
||||
mrope_pos_ptr = 0
|
||||
for index, req_id in enumerate(self.input_batch.req_ids):
|
||||
@@ -1667,6 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
||||
position=positions_cpu,
|
||||
attn_state=attn_state)
|
||||
self.fia_attn_mask = self._make_fia_attention_mask()
|
||||
self.attn_state = attn_state # type: ignore
|
||||
|
||||
self.with_prefill = with_prefill
|
||||
@@ -1899,6 +1906,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
positions=self.positions,
|
||||
attn_mask=self.attn_mask,
|
||||
fia_attn_mask=self.fia_attn_mask,
|
||||
spec_attn_mask=self.spec_attn_mask,
|
||||
attn_state=self.attn_state,
|
||||
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
|
||||
@@ -2756,13 +2764,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
|
||||
self.device).to(torch.int32)
|
||||
self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
|
||||
|
||||
self.query_start_loc[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
|
||||
self.query_start_loc_cpu[1:num_reqs +
|
||||
1] = torch.Tensor(cu_num_tokens)
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
|
||||
assigned_mask_dim = 2048
|
||||
self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
|
||||
assigned_mask_dim),
|
||||
diagonal=1).to(torch.int8).to(
|
||||
self.device)
|
||||
|
||||
num_computed_tokens_cpu = (
|
||||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
|
||||
@@ -2805,6 +2818,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
positions=self.positions,
|
||||
attn_mask=self.attn_mask,
|
||||
fia_attn_mask=self.fia_attn_mask,
|
||||
spec_attn_mask=self.spec_attn_mask,
|
||||
attn_state=self.attn_state,
|
||||
max_query_len=max_query_len,
|
||||
@@ -3978,10 +3992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
graph_support = None
|
||||
if hasattr(builder, 'aclgraph_support'):
|
||||
graph_support = builder.aclgraph_support.value
|
||||
builder_aclgraph = builder.aclgraph_support
|
||||
else:
|
||||
graph_support = builder.cudagraph_support.value
|
||||
builder_aclgraph = builder.cudagraph_support
|
||||
if graph_support < min_ag_support.value:
|
||||
min_ag_support = builder.aclgraph_support
|
||||
min_ag_support = builder_aclgraph
|
||||
min_ag_builder_name = builder.__class__.__name__
|
||||
|
||||
# This is an imitation of compilation_config.splitting_ops_contain_attention()
|
||||
|
||||
Reference in New Issue
Block a user