[perf]: support dual-batch overlap(dbo) for deepseek (#941)
### What this PR does / why we need it? Based on the design of dual-batch overlap proposed by Deepseek team and also the implementation of fused moe in VLLM project, we implement the multi-stream(also known as dual-batch) overlap for deepseek+mla on Ascend NPU. We split the input batch of model into two microbatches and then overlap the comp/comm ops in attention and moe layers using two streams to improve the performance. Our approach can be easily extended when adding dispatch/combine communications for moe layer. Compared with the previously proposed [draft](https://github.com/vllm-project/vllm-ascend/pull/842), we use one stream for computation ops and the other for communication ops, separately. In our opinion, it is beneficial for arranging the order of executing different ops and thus avoiding the contention of computation/communication resources. ref: [overlap for llama](https://github.com/vllm-project/vllm/pull/15787/files) ref: [dbo in sglang](https://github.com/sgl-project/sglang/pull/4068/files#diff-b4937569fc71f6ad215181b633b2f89c7183a2b4ac39e41fc22635599a9be7de) ### Does this PR introduce _any_ user-facing change? Adding an env variable "VLLM_ASCEND_ENABLE_DBO". Users can enable dbo by setting "VLLM_ASCEND_ENABLE_DBO=1" See /examples/offline_dualbatch_overlap_npu.py for more info. ### How was this patch tested? This patch can be tested with vllm-0.9.0 using its online service with benchmark tests. We have decoupled the func of dbo from vllm and it should be able to run without any modification to the code of vllm (some modifications are better implemented in vllm itself, though). Any advice/discussion is welcome. ### Performance Benchmark We have run the benchmark_serving script of vllm to test the performance after using dual-batch overlap.
`python -m vllm.entrypoints.openai.api_server \ --model=DeepSeek-R1-W8A8 \ --trust-remote-code \ --distributed-executor-backend=mp \ -tp=16 \ --port 8006 \ --max-num-seqs 390 \ --max-model-len 32768 \ --max-num-batched-tokens 65536 \ --block-size 128 \ --compilation_config 0 \ --gpu-memory-utilization 0.90 \ --disable-log-requests \ --additional-config '{"expert_tensor_parallel_size":1,"enable_inter_dp_scheduling":true,"init_torchair_graph_batch_sizes":true,"trace_recompiles":true,"ascend_scheduler_config":{},"enable_graph_mode":false}'` and run benchmark with the parameters of : `--dataset-name random --random-input-len 4096 --random-output-len 1 --num-prompts 200 --max-concurrency 8 --request-rate 5 --metric-percentiles 90` 1. test with the version using allgather+allreduce in Ascend 910B (tp16 ep16 + deepseek r1 w8a8) 2. test with the version using alltoall: prefill qps: 0.90 -> 1.01 Mean TTFT:8226->7432ms The overlap approach when using alltoall communication can be further optimized by overlapping micro-batch1's moe comp with micro-batch2's dispatch a2a comm --------- Signed-off-by: zhuohuan <zxdu1997@gmail.com>
This commit is contained in:
51
examples/offline_dualbatch_overlap_npu.py
Normal file
51
examples/offline_dualbatch_overlap_npu.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
# enable dual-batch overlap for vllm ascend
|
||||||
|
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"
|
||||||
|
os.environ["VLLM_USE_V1"] = "1"
|
||||||
|
|
||||||
|
# Sample prompts.
|
||||||
|
prompts = ["The president of the United States is"] * 41
|
||||||
|
# Create a sampling params object.
|
||||||
|
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Create an LLM.
|
||||||
|
llm = LLM(model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic",
|
||||||
|
enforce_eager=True,
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
max_model_len=4096,
|
||||||
|
trust_remote_code=True,
|
||||||
|
additional_config={
|
||||||
|
"torchair_graph_config": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"ascend_scheduler_config": {
|
||||||
|
"enabled": True
|
||||||
|
},
|
||||||
|
"expert_tensor_parallel_size": 1
|
||||||
|
})
|
||||||
|
|
||||||
|
# Generate texts from the prompts. The output is a list of RequestOutput
|
||||||
|
# objects that contain the prompt, generated text, and other information.
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
# Print the outputs.
|
||||||
|
print("-" * 50)
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
# Add a buffer to wait for profiler in the background process
|
||||||
|
# (in case MP is on) to finish writing profiling output.
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -81,3 +81,17 @@ def test_models_distributed_topk() -> None:
|
|||||||
distributed_executor_backend="mp",
|
distributed_executor_backend="mp",
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
vllm_model.generate(example_prompts, sampling_params)
|
vllm_model.generate(example_prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
|
||||||
|
def test_models_distributed_DeepSeek_dbo():
|
||||||
|
example_prompts = ["The president of the United States is"] * 41
|
||||||
|
dtype = "half"
|
||||||
|
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
|
||||||
|
with VllmRunner(
|
||||||
|
"deepseek-ai/DeepSeek-V2-Lite",
|
||||||
|
dtype=dtype,
|
||||||
|
tensor_parallel_size=4,
|
||||||
|
distributed_executor_backend="mp",
|
||||||
|
) as vllm_model:
|
||||||
|
vllm_model.generate(example_prompts, sampling_params)
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ from vllm.model_executor.layers.linear import (LinearBase,
|
|||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
|
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||||
|
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||||
|
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -117,6 +120,7 @@ class AscendMLAMetadata:
|
|||||||
|
|
||||||
with_prefill_across_dp: bool = False
|
with_prefill_across_dp: bool = False
|
||||||
|
|
||||||
|
query_lens: Optional[list[int]] = None
|
||||||
# The dimension of the attention heads
|
# The dimension of the attention heads
|
||||||
head_dim: Optional[int] = None
|
head_dim: Optional[int] = None
|
||||||
attn_mask: torch.Tensor = None
|
attn_mask: torch.Tensor = None
|
||||||
@@ -135,6 +139,17 @@ class AscendMLAMetadata:
|
|||||||
# f"Only {supported_head_sizes} are supported for head_dim,",
|
# f"Only {supported_head_sizes} are supported for head_dim,",
|
||||||
# f"received {self.head_dim}.")
|
# f"received {self.head_dim}.")
|
||||||
|
|
||||||
|
def split_metadata_for_multistream(
|
||||||
|
self,
|
||||||
|
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||||
|
) -> list["AscendMLAMetadata"]:
|
||||||
|
"""Split metadata for multi-stream with AscendMLAMetadata"""
|
||||||
|
return model_input_split_v1_mla_attn(
|
||||||
|
ms_split_config=ms_split_config,
|
||||||
|
attn_metadata=self,
|
||||||
|
_metadata_cls=AscendMLAMetadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
M = TypeVar("M", bound=AscendMLAMetadata)
|
M = TypeVar("M", bound=AscendMLAMetadata)
|
||||||
|
|
||||||
@@ -386,6 +401,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
|
|
||||||
return self.metadata_cls( # type: ignore
|
return self.metadata_cls( # type: ignore
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
|
query_lens=query_lens.tolist(),
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
head_dim=self.runner.model_config.get_head_size(),
|
head_dim=self.runner.model_config.get_head_size(),
|
||||||
num_decodes=self._num_decodes,
|
num_decodes=self._num_decodes,
|
||||||
@@ -585,7 +601,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
)
|
)
|
||||||
attn_output = attn_output.reshape(
|
attn_output = attn_output.reshape(
|
||||||
[num_tokens, self.num_heads * self.v_head_dim])
|
[num_tokens, self.num_heads * self.v_head_dim])
|
||||||
return self.o_proj(attn_output)[0]
|
|
||||||
|
current_ms_metadata = get_multistream_comm_context()
|
||||||
|
if current_ms_metadata is None:
|
||||||
|
return self.o_proj(attn_output)[0]
|
||||||
|
else:
|
||||||
|
current_ms_metadata.before_comm_event.record()
|
||||||
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||||
|
current_ms_metadata.before_comm_event.wait()
|
||||||
|
return self.o_proj(attn_output)[0]
|
||||||
|
|
||||||
def exec_kv(
|
def exec_kv(
|
||||||
self,
|
self,
|
||||||
@@ -685,7 +709,14 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
||||||
mla_vheadsize=self.kv_lora_rank,
|
mla_vheadsize=self.kv_lora_rank,
|
||||||
out=attn_output)
|
out=attn_output)
|
||||||
return self._v_up_proj_and_o_proj(attn_output)
|
current_ms_metadata = get_multistream_comm_context()
|
||||||
|
if current_ms_metadata is None:
|
||||||
|
return self._v_up_proj_and_o_proj(attn_output)
|
||||||
|
else:
|
||||||
|
current_ms_metadata.before_comm_event.record()
|
||||||
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||||
|
current_ms_metadata.before_comm_event.wait()
|
||||||
|
return self._v_up_proj_and_o_proj(attn_output)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -811,16 +842,38 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
key_cache=kv_cache,
|
key_cache=kv_cache,
|
||||||
slot_indices=attn_metadata.slot_mapping.flatten())
|
slot_indices=attn_metadata.slot_mapping.flatten())
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
output[num_decode_tokens:] = self._forward_prefill(
|
# FIX: aicore move should be also placed on the comm stream in dbo,
|
||||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
# otherwise it may affect the accuracy
|
||||||
attn_metadata)
|
# TODO: use an elegant way to overlap
|
||||||
|
output_prefill = self._forward_prefill(prefill_q,
|
||||||
|
prefill_k_c_normed,
|
||||||
|
prefill_k_pe, kv_cache,
|
||||||
|
attn_metadata)
|
||||||
|
current_ms_metadata = get_multistream_comm_context()
|
||||||
|
if current_ms_metadata is not None:
|
||||||
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||||
|
output[num_decode_tokens:] = output_prefill
|
||||||
|
current_ms_metadata.after_comm_event.record()
|
||||||
|
else:
|
||||||
|
output[num_decode_tokens:] = output_prefill
|
||||||
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
if self.running_in_graph:
|
if self.running_in_graph:
|
||||||
return self._forward_decode(decode_ql_nope, decode_q_pe,
|
return self._forward_decode(decode_ql_nope, decode_q_pe,
|
||||||
decode_k_nope, decode_k_pe,
|
decode_k_nope, decode_k_pe,
|
||||||
kv_cache, attn_metadata)
|
kv_cache, attn_metadata)
|
||||||
else:
|
else:
|
||||||
output[:num_decode_tokens] = self._forward_decode(
|
output_decode = self._forward_decode(decode_ql_nope,
|
||||||
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
|
decode_q_pe,
|
||||||
kv_cache, attn_metadata)
|
decode_k_nope,
|
||||||
|
decode_k_pe, kv_cache,
|
||||||
|
attn_metadata)
|
||||||
|
current_ms_metadata = get_multistream_comm_context()
|
||||||
|
if current_ms_metadata is not None:
|
||||||
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||||
|
output[:num_decode_tokens] = output_decode
|
||||||
|
current_ms_metadata.after_comm_event.record()
|
||||||
|
else:
|
||||||
|
output[:num_decode_tokens] = output_decode
|
||||||
|
|
||||||
return output_padded
|
return output_padded
|
||||||
|
|||||||
@@ -107,6 +107,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# Whether to enable the trace recompiles from pytorch.
|
# Whether to enable the trace recompiles from pytorch.
|
||||||
"VLLM_ASCEND_TRACE_RECOMPILES":
|
"VLLM_ASCEND_TRACE_RECOMPILES":
|
||||||
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
|
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
|
||||||
|
"VLLM_ASCEND_ENABLE_DBO":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
|
||||||
# Whether to enable the model execute time observe profile. Disable it when
|
# Whether to enable the model execute time observe profile. Disable it when
|
||||||
# running vllm ascend in production environment.
|
# running vllm ascend in production environment.
|
||||||
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
|
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
from vllm import ModelRegistry
|
from vllm import ModelRegistry
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs
|
||||||
|
|
||||||
|
|
||||||
def register_model():
|
def register_model():
|
||||||
|
from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
|
||||||
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
|
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
|
||||||
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
|
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
|
||||||
from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401
|
from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401
|
||||||
@@ -22,9 +25,14 @@ def register_model():
|
|||||||
"vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
|
"vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
|
||||||
)
|
)
|
||||||
|
|
||||||
ModelRegistry.register_model(
|
if envs.VLLM_ASCEND_ENABLE_DBO:
|
||||||
"DeepseekV2ForCausalLM",
|
ModelRegistry.register_model(
|
||||||
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
|
"DeepseekV2ForCausalLM",
|
||||||
|
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
||||||
|
else:
|
||||||
|
ModelRegistry.register_model(
|
||||||
|
"DeepseekV2ForCausalLM",
|
||||||
|
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
|
||||||
|
|
||||||
ModelRegistry.register_model(
|
ModelRegistry.register_model(
|
||||||
"DeepseekV3ForCausalLM",
|
"DeepseekV3ForCausalLM",
|
||||||
|
|||||||
1118
vllm_ascend/models/deepseek_dbo.py
Normal file
1118
vllm_ascend/models/deepseek_dbo.py
Normal file
File diff suppressed because it is too large
Load Diff
0
vllm_ascend/multistream/__init__.py
Normal file
0
vllm_ascend/multistream/__init__.py
Normal file
29
vllm_ascend/multistream/base.py
Normal file
29
vllm_ascend/multistream/base.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class MSEventKey(Enum):
|
||||||
|
ATTN_COM_FINISH = 0
|
||||||
|
ATTN_AR_FINISH = 1
|
||||||
|
FFN_COM_FINISH = 2
|
||||||
|
FFN_AR_FINISH = 3
|
||||||
|
# events for MOE dispatch and combine
|
||||||
|
MOE_BEFORE_COMM = 4
|
||||||
|
MOE_AFTER_COMM = 5
|
||||||
|
# events for shared expert
|
||||||
|
MOE_SE_COMM_FINISH = 6
|
||||||
|
MOE_SE_COMP_FINISH = 7
|
||||||
|
MOE_GATE_FINISH = 8
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MSAttentionMetadataSplitConfig:
|
||||||
|
"""
|
||||||
|
micro batch split config for split attention metadata
|
||||||
|
"""
|
||||||
|
# micro batch num
|
||||||
|
num_micro_batches: int = 2
|
||||||
|
# split micro batches only when total tokens >= min_total_tokens_to_split
|
||||||
|
min_total_tokens_to_split: int = 256
|
||||||
|
# split micro batches only when prefill tokens >= min_prefill_tokens_to_split
|
||||||
|
min_prefill_tokens_to_split: int = 64
|
||||||
67
vllm_ascend/multistream/context.py
Normal file
67
vllm_ascend/multistream/context.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
_ms_comm_context: Any = None
|
||||||
|
_cur_micro_batch_num: int = -1
|
||||||
|
_ms_layer_index_context: int = -1
|
||||||
|
_ms_metadata_context: Any = None
|
||||||
|
_ms_attn_metadata_context: Any = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
|
||||||
|
attn_metadata: Any):
|
||||||
|
"""
|
||||||
|
set multistream layer context before transformer layers
|
||||||
|
"""
|
||||||
|
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||||
|
_ms_layer_index_context = start_layer
|
||||||
|
_ms_metadata_context = ms_metadata
|
||||||
|
_ms_attn_metadata_context = attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def reset_multistream_layer_context():
|
||||||
|
"""
|
||||||
|
reset multistream layer context
|
||||||
|
"""
|
||||||
|
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||||
|
_ms_layer_index_context = -1
|
||||||
|
_ms_metadata_context = None
|
||||||
|
_ms_attn_metadata_context = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_multistream_layer_context():
|
||||||
|
"""
|
||||||
|
get multistream layer context
|
||||||
|
"""
|
||||||
|
return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||||
|
|
||||||
|
|
||||||
|
def advance_step_multistream_layer_context():
|
||||||
|
"""
|
||||||
|
advance multistream layer index context
|
||||||
|
"""
|
||||||
|
global _ms_layer_index_context
|
||||||
|
_ms_layer_index_context += 1
|
||||||
|
|
||||||
|
|
||||||
|
def get_multistream_comm_context() -> Any:
|
||||||
|
"""Get the current comm forward context."""
|
||||||
|
return _ms_comm_context
|
||||||
|
|
||||||
|
|
||||||
|
def get_multistream_microbatch_context() -> int:
|
||||||
|
return _cur_micro_batch_num
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def set_multistream_context(context: Any, micro_batch_num: int):
|
||||||
|
"""A context manager that stores the current comm forward context,
|
||||||
|
can be attention metadata, etc."""
|
||||||
|
global _ms_comm_context, _cur_micro_batch_num
|
||||||
|
_ms_comm_context = context
|
||||||
|
_cur_micro_batch_num = micro_batch_num
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_ms_comm_context = None
|
||||||
|
_cur_micro_batch_num = -1
|
||||||
26
vllm_ascend/multistream/decorator.py
Normal file
26
vllm_ascend/multistream/decorator.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
from .context import (get_multistream_layer_context,
|
||||||
|
get_multistream_microbatch_context)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# vllm v1 use get_forward_context to get the attn_metadata,
|
||||||
|
# we can use this decorator to update the attn metadata
|
||||||
|
def set_multistream_support():
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
|
||||||
|
def wrapper():
|
||||||
|
context = func()
|
||||||
|
layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
|
||||||
|
)
|
||||||
|
micro_batch_num = get_multistream_microbatch_context()
|
||||||
|
if layer_index != -1 and micro_batch_num != -1:
|
||||||
|
context.attn_metadata = attn_metadata[micro_batch_num]
|
||||||
|
return context
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
61
vllm_ascend/multistream/layers.py
Normal file
61
vllm_ascend/multistream/layers.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
|
from .base import MSEventKey
|
||||||
|
from .context import (get_multistream_layer_context,
|
||||||
|
reset_multistream_layer_context,
|
||||||
|
set_multistream_layer_context)
|
||||||
|
from .metadata import MultiStreamMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class MultiStreamPreTransformerLayer(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, multistream_metadata: MultiStreamMetadata):
|
||||||
|
super().__init__()
|
||||||
|
self.multistream_metadata = multistream_metadata
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
intput_tensors: List[torch.Tensor],
|
||||||
|
):
|
||||||
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
|
if self.multistream_metadata is None or attn_metadata is None:
|
||||||
|
set_multistream_layer_context(-1, None, None)
|
||||||
|
return attn_metadata, intput_tensors
|
||||||
|
# TODO add attn_metadata management
|
||||||
|
do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(
|
||||||
|
attn_metadata, intput_tensors)
|
||||||
|
if do_ms:
|
||||||
|
set_multistream_layer_context(
|
||||||
|
self.multistream_metadata.start_layer,
|
||||||
|
self.multistream_metadata, attn_metadata)
|
||||||
|
else:
|
||||||
|
set_multistream_layer_context(-1, None, None)
|
||||||
|
return attn_metadata, intput_tensors
|
||||||
|
|
||||||
|
|
||||||
|
class MultiStreamPostTransformerLayer(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, multistream_metadata: MultiStreamMetadata):
|
||||||
|
super().__init__()
|
||||||
|
self.multistream_metadata = multistream_metadata
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
input_tensors: Union[List[Tuple[torch.Tensor]],
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[List[torch.Tensor]]],
|
||||||
|
wait_layer_index: Optional[int] = None):
|
||||||
|
if self.multistream_metadata is None or self.multistream_metadata.ms_config is None:
|
||||||
|
return input_tensors
|
||||||
|
layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context(
|
||||||
|
)
|
||||||
|
if layer_index >= 0:
|
||||||
|
true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index
|
||||||
|
self.multistream_metadata.try_wait_event(
|
||||||
|
true_wait_layer,
|
||||||
|
self.multistream_metadata.ms_config.num_micro_batches - 1,
|
||||||
|
MSEventKey.FFN_AR_FINISH)
|
||||||
|
reset_multistream_layer_context()
|
||||||
|
return self.multistream_metadata.merge_micro_batches(input_tensors)
|
||||||
182
vllm_ascend/multistream/metadata.py
Normal file
182
vllm_ascend/multistream/metadata.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||||
|
|
||||||
|
from .base import MSAttentionMetadataSplitConfig, MSEventKey
|
||||||
|
|
||||||
|
|
||||||
|
def split_micro_batches_tensors(input_tensors,
|
||||||
|
split_index: int,
|
||||||
|
keys: Optional[List[str]] = None):
|
||||||
|
if isinstance(input_tensors, list):
|
||||||
|
micro_batches = []
|
||||||
|
for tensor in input_tensors:
|
||||||
|
if tensor is None:
|
||||||
|
micro_batches.append([None, None])
|
||||||
|
else:
|
||||||
|
micro_batches.append(
|
||||||
|
[tensor[:split_index], tensor[split_index:]])
|
||||||
|
return micro_batches
|
||||||
|
elif isinstance(input_tensors, torch.Tensor):
|
||||||
|
return [input_tensors[:split_index], input_tensors[split_index:]]
|
||||||
|
elif input_tensors is None:
|
||||||
|
return [None, None]
|
||||||
|
elif isinstance(input_tensors, Dict):
|
||||||
|
assert keys is not None
|
||||||
|
micro_batches_pre = {}
|
||||||
|
for key in keys:
|
||||||
|
micro_batches_pre[key] = input_tensors[key][:split_index]
|
||||||
|
micro_batches_post = {}
|
||||||
|
for key in keys:
|
||||||
|
micro_batches_post[key] = input_tensors[key][split_index:]
|
||||||
|
return [micro_batches_pre, micro_batches_post]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MultiStreamStepMetadata:
|
||||||
|
comm_stream: torch.npu.Stream = None
|
||||||
|
before_comm_event: torch.npu.Event = None
|
||||||
|
after_comm_event: torch.npu.Event = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MultiStreamConfig:
|
||||||
|
"""Controls the behavior of multi-stream models."""
|
||||||
|
min_total_tokens_to_split: int = 256
|
||||||
|
min_prefill_tokens_to_split: int = 64
|
||||||
|
num_micro_batches: int = 2
|
||||||
|
imbalance_ratio: float = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class MultiStreamMetadata:
|
||||||
|
# direct stream
|
||||||
|
calculate_stream = None
|
||||||
|
# delay stream
|
||||||
|
communicate_stream = None
|
||||||
|
# events
|
||||||
|
ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {}
|
||||||
|
# multi-stream-flag
|
||||||
|
enable_multi_stream: bool = False
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
calculate_stream: torch.npu.Stream,
|
||||||
|
communicate_stream: torch.npu.Stream,
|
||||||
|
start_layer: int,
|
||||||
|
end_layer: int,
|
||||||
|
event_keys: List[MSEventKey],
|
||||||
|
multistream_config: Optional[MultiStreamConfig],
|
||||||
|
causal_lm: bool = True,
|
||||||
|
):
|
||||||
|
self.calculate_stream = calculate_stream
|
||||||
|
self.communicate_stream = communicate_stream
|
||||||
|
self.start_layer = start_layer
|
||||||
|
self.end_layer = end_layer
|
||||||
|
self.ms_config = multistream_config
|
||||||
|
self.causal_lm = causal_lm
|
||||||
|
self._build_events(event_keys)
|
||||||
|
self._build_ms_split_config()
|
||||||
|
|
||||||
|
def _build_events(self, event_keys):
|
||||||
|
if self.ms_config is not None:
|
||||||
|
for i in range(self.start_layer - 1, self.end_layer):
|
||||||
|
self.ms_events[i] = {}
|
||||||
|
for j in range(self.ms_config.num_micro_batches):
|
||||||
|
self.ms_events[i][j] = {}
|
||||||
|
for key in event_keys:
|
||||||
|
self.ms_events[i][j][key] = torch.npu.Event()
|
||||||
|
|
||||||
|
def _build_ms_split_config(self):
|
||||||
|
if self.ms_config is not None:
|
||||||
|
self.ms_split_config = MSAttentionMetadataSplitConfig(
|
||||||
|
num_micro_batches=self.ms_config.num_micro_batches,
|
||||||
|
min_total_tokens_to_split=self.ms_config.
|
||||||
|
min_total_tokens_to_split,
|
||||||
|
min_prefill_tokens_to_split=self.ms_config.
|
||||||
|
min_prefill_tokens_to_split,
|
||||||
|
)
|
||||||
|
|
||||||
|
def try_wait_event(self, layer_index: int, micro_batch_index: int,
|
||||||
|
event_key: MSEventKey):
|
||||||
|
self.ms_events[layer_index][micro_batch_index][event_key].wait()
|
||||||
|
|
||||||
|
def try_record_event(self, layer_index: int, micro_batch_index: int,
|
||||||
|
event_key: MSEventKey):
|
||||||
|
self.ms_events[layer_index][micro_batch_index][event_key].record()
|
||||||
|
|
||||||
|
def split_micro_batch(
|
||||||
|
self,
|
||||||
|
attn_metadata: "AscendMLAMetadata",
|
||||||
|
intput_tensors: List[torch.Tensor],
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
intermediate_tensors_keys: Optional[List[str]] = None,
|
||||||
|
) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[
|
||||||
|
List[torch.Tensor], List[List[torch.Tensor]]], Union[
|
||||||
|
IntermediateTensors, List[IntermediateTensors]]]:
|
||||||
|
attn_metadata_list = attn_metadata.split_metadata_for_multistream(
|
||||||
|
self.ms_split_config)
|
||||||
|
if len(attn_metadata_list) == 1:
|
||||||
|
return False, attn_metadata_list[
|
||||||
|
0], intput_tensors, intermediate_tensors
|
||||||
|
split_index = attn_metadata_list[0].slot_mapping.shape[0]
|
||||||
|
input_tensors = split_micro_batches_tensors(intput_tensors,
|
||||||
|
split_index)
|
||||||
|
if intermediate_tensors is not None:
|
||||||
|
inter_tensors_list = split_micro_batches_tensors(
|
||||||
|
intermediate_tensors.tensors, split_index,
|
||||||
|
intermediate_tensors_keys)
|
||||||
|
intermediate_tensors = [
|
||||||
|
IntermediateTensors(inter_tensors)
|
||||||
|
for inter_tensors in inter_tensors_list
|
||||||
|
]
|
||||||
|
return True, attn_metadata_list, input_tensors, intermediate_tensors
|
||||||
|
|
||||||
|
def merge_micro_batches(
|
||||||
|
self, input_tensors: Union[List[torch.Tensor],
|
||||||
|
List[List[torch.Tensor]]]
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
if input_tensors is None or isinstance(input_tensors[0], torch.Tensor):
|
||||||
|
return input_tensors
|
||||||
|
batch: List[Optional[torch.Tensor]] = []
|
||||||
|
for tensors in input_tensors:
|
||||||
|
if tensors is None or tensors[0] is None:
|
||||||
|
batch.append(None)
|
||||||
|
else:
|
||||||
|
batch.append(torch.cat(tensors, dim=0))
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def make_multistream_metadata_ds(
|
||||||
|
start_layer: int,
|
||||||
|
end_layer: int,
|
||||||
|
causal_lm: bool = True,
|
||||||
|
multistream_config: Optional[MultiStreamConfig] = None,
|
||||||
|
):
|
||||||
|
if multistream_config is None:
|
||||||
|
return None
|
||||||
|
event_keylist = [
|
||||||
|
MSEventKey.ATTN_COM_FINISH,
|
||||||
|
MSEventKey.ATTN_AR_FINISH,
|
||||||
|
MSEventKey.FFN_COM_FINISH,
|
||||||
|
MSEventKey.FFN_AR_FINISH,
|
||||||
|
MSEventKey.MOE_BEFORE_COMM,
|
||||||
|
MSEventKey.MOE_AFTER_COMM,
|
||||||
|
MSEventKey.MOE_SE_COMM_FINISH,
|
||||||
|
MSEventKey.MOE_SE_COMP_FINISH,
|
||||||
|
MSEventKey.MOE_GATE_FINISH,
|
||||||
|
]
|
||||||
|
return MultiStreamMetadata(
|
||||||
|
calculate_stream=torch.npu.current_stream(),
|
||||||
|
communicate_stream=torch.npu.Stream(),
|
||||||
|
start_layer=start_layer,
|
||||||
|
end_layer=end_layer,
|
||||||
|
multistream_config=multistream_config,
|
||||||
|
event_keys=event_keylist,
|
||||||
|
causal_lm=causal_lm,
|
||||||
|
)
|
||||||
245
vllm_ascend/multistream/ms_split.py
Normal file
245
vllm_ascend/multistream/ms_split.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
|
|
||||||
|
from .base import MSAttentionMetadataSplitConfig
|
||||||
|
|
||||||
|
|
||||||
|
def compute_split_seq_index(
|
||||||
|
query_lens: Optional[list[int]],
|
||||||
|
attn_state: AscendAttentionState,
|
||||||
|
num_tokens: int,
|
||||||
|
imbalance_ratio: float = 0.1,
|
||||||
|
) -> list[int]:
|
||||||
|
if attn_state != AscendAttentionState.DecodeOnly:
|
||||||
|
assert query_lens is not None
|
||||||
|
total_tokens = sum(query_lens)
|
||||||
|
# the first index in last split
|
||||||
|
tokens, split_index = 0, 0
|
||||||
|
for value in query_lens:
|
||||||
|
tokens += value
|
||||||
|
split_index += 1
|
||||||
|
if tokens >= total_tokens // 2:
|
||||||
|
# check the current split index
|
||||||
|
if abs(tokens -
|
||||||
|
total_tokens // 2) < total_tokens * imbalance_ratio:
|
||||||
|
return [tokens, split_index]
|
||||||
|
# check the previous split index
|
||||||
|
elif abs(tokens - total_tokens // 2 -
|
||||||
|
value) < total_tokens * imbalance_ratio:
|
||||||
|
return [tokens - value, split_index - 1]
|
||||||
|
# fail to split if it is imbalanced
|
||||||
|
# TODO: split tokens in seq
|
||||||
|
else:
|
||||||
|
return [0, 0]
|
||||||
|
else:
|
||||||
|
tokens = num_tokens // 2
|
||||||
|
return [tokens, tokens]
|
||||||
|
return [0, 0]
|
||||||
|
|
||||||
|
|
||||||
|
def split_attn_tensor_type(
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
index: int,
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
return [input_tensor[:index], input_tensor[index:]]
|
||||||
|
|
||||||
|
|
||||||
|
def split_attn_int_type(
    var: int,
    index: int,
) -> List[int]:
    """Split an integer count *var* at *index* into two non-negative parts.

    Returns ``[min(var, index), max(var - index, 0)]`` so the two parts always
    sum to ``var`` and the first part never exceeds ``index``.

    Fix: the original return annotation was ``List[torch.Tensor]``, but the
    function returns two plain ints.
    """
    return [min(var, index), max(var - index, 0)]
|
||||||
|
|
||||||
|
|
||||||
|
def model_input_split_v1_mla_attn(
    attn_metadata,
    _metadata_cls,
    ms_split_config: MSAttentionMetadataSplitConfig,
) -> List[Any]:
    """Split one MLA attention-metadata object into two micro-batch metadatas.

    Used by dual-batch overlap (DBO): the batch is split at the
    token/sequence boundary chosen by ``compute_split_seq_index`` and two
    metadata objects (built with ``_metadata_cls``) are returned.

    :param attn_metadata: combined metadata to split (project MLA metadata
        type -- presumably AscendMLAMetadata; confirm at the caller).
    :param _metadata_cls: constructor used to build each half's metadata.
    :param ms_split_config: multistream config; only ``num_micro_batches``
        of 1 or 2 is supported.
    :return: ``[attn_metadata]`` unchanged when splitting is impossible,
        otherwise ``[attention_metadata_pre, attention_metadata_post]``.
    """
    assert 0 < ms_split_config.num_micro_batches < 3
    if attn_metadata is None:
        return [attn_metadata]
    # token_index / seq_index mark where micro-batch 1 ends (exclusive).
    [token_index,
     seq_index] = compute_split_seq_index(attn_metadata.query_lens,
                                          attn_metadata.attn_state,
                                          attn_metadata.num_decode_tokens)
    # Degenerate split (everything lands in one half): keep a single batch.
    if token_index == 0 or seq_index == 0 or seq_index == len(
            attn_metadata.query_lens):
        return [attn_metadata]

    # NOTE(review): query_start_loc_cpu and prefill_query_start_loc are
    # computed here but never read below -- looks like leftover code; confirm.
    query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
                                   dtype=int)
    np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
    if attn_metadata.num_prefills > 0:
        prefill_query_start_loc = np.zeros(
            shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
        np.cumsum(attn_metadata.prefill.query_lens,
                  out=prefill_query_start_loc[1:])

    # --- split the flat per-token / per-sequence fields -------------------
    [slot_mapping_pre,
     slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
                                                 token_index)
    [num_decodes_pre,
     num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
                                             seq_index)
    [num_decode_tokens_pre, num_decode_tokens_post
     ] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
    # Prefill sequences come after the decode sequences, hence the offset.
    [num_prefills_pre, num_prefills_post
     ] = split_attn_int_type(attn_metadata.num_prefills,
                             max(0, seq_index - attn_metadata.num_decodes))
    seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
    [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)

    query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
    # Rebase the post half so its first cumulative offset is 0.
    query_start_loc_post = deepcopy(
        attn_metadata.query_start_loc[seq_index:]
    ) - attn_metadata.query_start_loc[seq_index]
    [block_table_pre,
     block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
                                                seq_index)

    # --- choose attention state and mask for each half --------------------
    if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
        # the attn_mla kernel in torch npu only accept 128*128 attn mask
        attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
        attn_state_pre = attn_state_post = attn_metadata.attn_state
    elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
        # should be none in decode only state
        attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
        attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
    else:
        # chunked prefill: the pre half may end up decode-only.
        if num_prefills_pre > 0:
            # NOTE(review): attn_state_post is assigned twice in this branch;
            # the combined assignment on the next line looks redundant.
            attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
            attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
                seq_lens_pre)].contiguous()
            attn_state_post = AscendAttentionState.ChunkedPrefill
            attn_mask_post = attn_metadata.attn_mask[
                token_index:, :max(seq_lens_post)].contiguous()
        else:
            attn_state_pre = AscendAttentionState.DecodeOnly
            attn_mask_pre = None
            attn_state_post = AscendAttentionState.ChunkedPrefill
            attn_mask_post = attn_metadata.attn_mask[
                token_index:, :max(seq_lens_post)].contiguous()
    # Function-scope import -- presumably to avoid a circular import; confirm.
    from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
                                              AscendMLAPrefillMetadata)
    if num_prefills_pre > 0:
        # Both halves hold prefill work: split metadata.prefill.  All decode
        # sequences (which precede the prefills) stay in the pre half.
        [input_positions_pre, input_positions_post] = split_attn_tensor_type(
            attn_metadata.prefill.input_positions,
            token_index - attn_metadata.num_decode_tokens)
        [block_tables_pre, block_tables_post
         ] = split_attn_tensor_type(attn_metadata.prefill.block_table,
                                    seq_index - attn_metadata.num_decodes)
        [prefill_query_lens_pre, prefill_query_lens_post
         ] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
                                    seq_index - attn_metadata.num_decodes)
        prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[:
                                                                            seq_index
                                                                            +
                                                                            1 -
                                                                            attn_metadata
                                                                            .
                                                                            num_decodes]
        # Rebase the post half's cumulative offsets to start at 0.
        prefill_query_start_loc_post = deepcopy(
            attn_metadata.prefill.query_start_loc[seq_index -
                                                  attn_metadata.num_decodes:]
        ) - attn_metadata.prefill.query_start_loc[seq_index -
                                                  attn_metadata.num_decodes]
        context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
        context_len_post = seq_lens_post
        prefill_max_query_len_pre = max(prefill_query_lens_pre)
        prefill_max_query_len_post = max(prefill_query_lens_post)
        prefill_pre = AscendMLAPrefillMetadata(
            attn_mask=attn_mask_pre,
            query_lens=prefill_query_lens_pre,
            seq_lens=seq_lens_pre,
            query_start_loc=prefill_query_start_loc_pre,
            input_positions=input_positions_pre,
            context_lens=context_len_pre,
            block_table=block_tables_pre,
            max_query_len=prefill_max_query_len_pre,
            max_seq_lens=context_len_pre.max().item(),
        )
        prefill_post = AscendMLAPrefillMetadata(
            attn_mask=attn_mask_post,
            query_lens=prefill_query_lens_post,
            seq_lens=seq_lens_post,
            query_start_loc=prefill_query_start_loc_post,
            input_positions=input_positions_post,
            context_lens=context_len_post,
            block_table=block_tables_post,
            max_query_len=prefill_max_query_len_post,
            max_seq_lens=context_len_post.max().item(),
        )
        # Decode work stays whole in the pre half.
        decode_pre = attn_metadata.decode
        decode_post = None
    else:
        # prefill is None, split metadata.decode
        [input_positions_pre, input_positions_post
         ] = split_attn_tensor_type(attn_metadata.decode.input_positions,
                                    token_index)
        [block_tables_pre, block_tables_post
         ] = split_attn_tensor_type(attn_metadata.decode.block_table,
                                    seq_index)
        # NOTE(review): this duplicates the seq_lens_pre/seq_lens_post split
        # computed above -- confirm whether it can be reused.
        [decode_seq_lens_pre,
         decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
        decode_pre = AscendMLADecodeMetadata(
            input_positions=input_positions_pre,
            block_table=block_tables_pre,
            seq_lens=decode_seq_lens_pre,
            max_seq_lens=max(decode_seq_lens_pre),
            seq_lens_list=decode_seq_lens_pre.tolist(),
        )
        decode_post = AscendMLADecodeMetadata(
            input_positions=input_positions_post,
            block_table=block_tables_post,
            seq_lens=decode_seq_lens_post,
            max_seq_lens=max(decode_seq_lens_post),
            seq_lens_list=decode_seq_lens_post.tolist(),
        )
        prefill_pre = None
        # Prefill work (all of it, if any) goes to the post half.
        prefill_post = attn_metadata.prefill
    # construct metadata
    # NOTE(review): redundant re-import; already imported above.
    from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata
    attention_metadata_pre = _metadata_cls(
        num_actual_tokens=token_index,
        num_input_tokens=token_index,
        head_dim=attn_metadata.head_dim,
        slot_mapping=slot_mapping_pre,
        seq_lens=seq_lens_pre,
        query_start_loc=query_start_loc_pre,
        block_tables=block_table_pre,
        num_decodes=num_decodes_pre,
        num_prefills=num_prefills_pre,
        num_decode_tokens=num_decode_tokens_pre,
        attn_state=attn_state_pre,
        attn_mask=attn_mask_pre,
        prefill=prefill_pre,
        decode=decode_pre,
        with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
    )
    attention_metadata_post = _metadata_cls(
        num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
        num_input_tokens=attn_metadata.num_input_tokens - token_index,
        head_dim=attn_metadata.head_dim,
        slot_mapping=slot_mapping_post,
        seq_lens=seq_lens_post,
        query_start_loc=query_start_loc_post,
        block_tables=block_table_post,
        num_decodes=num_decodes_post,
        num_prefills=num_prefills_post,
        num_decode_tokens=num_decode_tokens_post,
        attn_mask=attn_mask_post,
        attn_state=attn_state_post,
        prefill=prefill_post,
        decode=decode_post,
        with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
    )
    return [attention_metadata_pre, attention_metadata_post]
|
||||||
@@ -1151,3 +1151,32 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
if self.enable_multistream_shared_expert and not is_prefill:
|
if self.enable_multistream_shared_expert and not is_prefill:
|
||||||
return hidden_states, shared_output
|
return hidden_states, shared_output
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
# ----------------------------------------- TBO-related --------------------------------------------
|
||||||
|
|
||||||
|
def _forward_ms_fused_moe_comp(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
is_prefill: bool,
|
||||||
|
real_top_k,
|
||||||
|
enable_force_load_balance: bool = False,
|
||||||
|
):
|
||||||
|
hidden_states = self.quant_method.apply(
|
||||||
|
layer=self,
|
||||||
|
x=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=real_top_k,
|
||||||
|
renormalize=self.renormalize,
|
||||||
|
use_grouped_topk=self.use_grouped_topk,
|
||||||
|
global_num_experts=self.global_num_experts,
|
||||||
|
expert_map=self.expert_map,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
num_expert_group=self.num_expert_group,
|
||||||
|
custom_routing_function=self.custom_routing_function,
|
||||||
|
scoring_func=self.scoring_func,
|
||||||
|
e_score_correction_bias=self.e_score_correction_bias,
|
||||||
|
is_prefill=is_prefill,
|
||||||
|
enable_force_load_balance=enable_force_load_balance)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|||||||
Reference in New Issue
Block a user