[Attention][Kernel]moe support for llama4 and mllama4 (#740)
### What this PR does / why we need it?
moe support for llama4 and mllama4 in vllm-ascend
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
start sever:
python -m vllm.entrypoints.openai.api_server --model
/data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct \
--max-num-seqs=256 \
--max-model-len=8192 \
--tensor-parallel-size=8 \
--block-size=128 \
--dtype bfloat16 \
--host=0.0.0.0 \
--port=8000 \
--gpu-memory-utilization=0.9 \
--trust-remote-code
client:
python online_server.py --model-path
/data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct
--image-path /data/nfs/w60040464/cherry_blossom.jpg --docker-ip
7.242.108.253 --served-port 8000 --text "what is the content of this
image?"
result:
{'id': 'chatcmpl-2b709a5d2e1a4017991ec4ba8248686a', 'object':
'chat.completion', 'created': 1747056823, 'model':
'/data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct',
'choices': [{'index': 0, 'message': {'role': 'assistant',
'reasoning_content': None, 'content': 'The image depicts a tower, likely
Tokyo Skytree, framed by branches of a cherry blossom tree. The tower is
white and has a distinctive shape, with a large sphere at the top and a
long, thin spire extending from it. The branches of the cherry blossom
tree are in the foreground, with pink flowers blooming on them. The
background is a clear blue sky.\n\n**Key Features:**\n\n* **Tower:**
White, spherical shape at the top, long thin spire\n', 'tool_calls':
[]}, 'logprobs': None, 'finish_reason': 'length', 'stop_reason': None}],
'usage': {'prompt_tokens': 2340, 'total_tokens': 2440,
'completion_tokens': 100, 'prompt_tokens_details': None},
'prompt_logprobs': None}
Signed-off-by: chenxu <chenxu68@huawei.com>
Co-authored-by: chenxu <chenxu68@huawei.com>
Co-authored-by: evian <eviantai@u.nus.edu>
This commit is contained in:
@@ -708,6 +708,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
|
||||
@@ -174,6 +174,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
|
||||
@@ -55,13 +55,15 @@ def forward_oot(
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
|
||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
||||
|
||||
@@ -153,6 +153,7 @@ def fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fused experts with top-k routing.
|
||||
@@ -191,6 +192,15 @@ def fused_experts(
|
||||
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
||||
# ], "Only float32, float16, and bfloat16 are supported"
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
||||
|
||||
if expert_map is not None:
|
||||
# Generate token indices and flatten
|
||||
token_indices = (torch.arange(num_tokens,
|
||||
@@ -292,6 +302,8 @@ def fused_experts(
|
||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
||||
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
|
||||
else:
|
||||
scales = torch.ones_like(
|
||||
topk_weights) if apply_router_weight_on_input else topk_weights
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
@@ -299,7 +311,7 @@ def fused_experts(
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=topk_weights,
|
||||
scales=scales,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
@@ -366,9 +378,6 @@ def select_experts(
|
||||
Raises:
|
||||
ValueError: If an unsupported scoring function is provided.
|
||||
"""
|
||||
if custom_routing_function is not None:
|
||||
raise NotImplementedError(
|
||||
"Custom routing function is not supported now")
|
||||
|
||||
if scoring_func == "softmax":
|
||||
# NOTE: vLLM use dtype=torch.float here
|
||||
@@ -405,9 +414,18 @@ def select_experts(
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)
|
||||
else:
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize)
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
|
||||
Reference in New Issue
Block a user