[Attention][Kernel]moe support for llama4 and mllama4 (#740)
### What this PR does / why we need it?
moe support for llama4 and mllama4 in vllm-ascend
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
start server:
python -m vllm.entrypoints.openai.api_server --model
/data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct \
--max-num-seqs=256 \
--max-model-len=8192 \
--tensor-parallel-size=8 \
--block-size=128 \
--dtype bfloat16 \
--host=0.0.0.0 \
--port=8000 \
--gpu-memory-utilization=0.9 \
--trust-remote-code
client:
python online_server.py --model-path
/data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct
--image-path /data/nfs/w60040464/cherry_blossom.jpg --docker-ip
7.242.108.253 --served-port 8000 --text "what is the content of this
image?"
result:
{'id': 'chatcmpl-2b709a5d2e1a4017991ec4ba8248686a', 'object':
'chat.completion', 'created': 1747056823, 'model':
'/data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct',
'choices': [{'index': 0, 'message': {'role': 'assistant',
'reasoning_content': None, 'content': 'The image depicts a tower, likely
Tokyo Skytree, framed by branches of a cherry blossom tree. The tower is
white and has a distinctive shape, with a large sphere at the top and a
long, thin spire extending from it. The branches of the cherry blossom
tree are in the foreground, with pink flowers blooming on them. The
background is a clear blue sky.\n\n**Key Features:**\n\n* **Tower:**
White, spherical shape at the top, long thin spire\n', 'tool_calls':
[]}, 'logprobs': None, 'finish_reason': 'length', 'stop_reason': None}],
'usage': {'prompt_tokens': 2340, 'total_tokens': 2440,
'completion_tokens': 100, 'prompt_tokens_details': None},
'prompt_logprobs': None}
Signed-off-by: chenxu <chenxu68@huawei.com>
Co-authored-by: chenxu <chenxu68@huawei.com>
Co-authored-by: evian <eviantai@u.nus.edu>
This commit is contained in:
@@ -708,6 +708,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
attn_type: str = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
|
use_irope: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
|
|||||||
@@ -174,6 +174,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
attn_type: str = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
|
use_irope: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
|
|||||||
@@ -55,13 +55,15 @@ def forward_oot(
|
|||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
return fused_experts(hidden_states=x,
|
return fused_experts(
|
||||||
w1=layer.w13_weight,
|
hidden_states=x,
|
||||||
w2=layer.w2_weight,
|
w1=layer.w13_weight,
|
||||||
topk_weights=topk_weights,
|
w2=layer.w2_weight,
|
||||||
topk_ids=topk_ids,
|
topk_weights=topk_weights,
|
||||||
top_k=top_k,
|
topk_ids=topk_ids,
|
||||||
expert_map=expert_map)
|
top_k=top_k,
|
||||||
|
expert_map=expert_map,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||||
|
|
||||||
|
|
||||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
||||||
|
|||||||
@@ -153,6 +153,7 @@ def fused_experts(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
expert_map: torch.Tensor = None,
|
expert_map: torch.Tensor = None,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Fused experts with top-k routing.
|
Fused experts with top-k routing.
|
||||||
@@ -191,6 +192,15 @@ def fused_experts(
|
|||||||
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
||||||
# ], "Only float32, float16, and bfloat16 are supported"
|
# ], "Only float32, float16, and bfloat16 are supported"
|
||||||
|
|
||||||
|
if apply_router_weight_on_input:
|
||||||
|
assert (topk_weights.dim() == 2
|
||||||
|
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||||
|
_, topk = topk_weights.shape
|
||||||
|
assert (
|
||||||
|
topk == 1
|
||||||
|
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||||
|
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
||||||
|
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
# Generate token indices and flatten
|
# Generate token indices and flatten
|
||||||
token_indices = (torch.arange(num_tokens,
|
token_indices = (torch.arange(num_tokens,
|
||||||
@@ -292,6 +302,8 @@ def fused_experts(
|
|||||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
torch.zeros_like(weighted_down_out)).to(dtype)
|
||||||
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
|
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
|
||||||
else:
|
else:
|
||||||
|
scales = torch.ones_like(
|
||||||
|
topk_weights) if apply_router_weight_on_input else topk_weights
|
||||||
# TODO: Reorder device memory 2 times here, replace the current
|
# TODO: Reorder device memory 2 times here, replace the current
|
||||||
# implementation here when suitable operators become available.
|
# implementation here when suitable operators become available.
|
||||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||||
@@ -299,7 +311,7 @@ def fused_experts(
|
|||||||
skip1=None,
|
skip1=None,
|
||||||
skip2=None,
|
skip2=None,
|
||||||
bias=None,
|
bias=None,
|
||||||
scales=topk_weights,
|
scales=scales,
|
||||||
expanded_src_to_dst_row=expanded_row_idx,
|
expanded_src_to_dst_row=expanded_row_idx,
|
||||||
export_for_source_row=topk_ids,
|
export_for_source_row=topk_ids,
|
||||||
)
|
)
|
||||||
@@ -366,9 +378,6 @@ def select_experts(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If an unsupported scoring function is provided.
|
ValueError: If an unsupported scoring function is provided.
|
||||||
"""
|
"""
|
||||||
if custom_routing_function is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Custom routing function is not supported now")
|
|
||||||
|
|
||||||
if scoring_func == "softmax":
|
if scoring_func == "softmax":
|
||||||
# NOTE: vLLM use dtype=torch.float here
|
# NOTE: vLLM use dtype=torch.float here
|
||||||
@@ -405,9 +414,18 @@ def select_experts(
|
|||||||
k=top_k,
|
k=top_k,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
sorted=False)
|
sorted=False)
|
||||||
else:
|
elif custom_routing_function is None:
|
||||||
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
||||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||||
|
else:
|
||||||
|
topk_weights, topk_ids = custom_routing_function(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
gating_output=router_logits,
|
||||||
|
topk=top_k,
|
||||||
|
renormalize=renormalize)
|
||||||
|
# Required by npu_moe_init_routing
|
||||||
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
# Required by npu_moe_init_routing
|
# Required by npu_moe_init_routing
|
||||||
topk_ids = topk_ids.to(torch.int32)
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
|
|||||||
Reference in New Issue
Block a user