forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
@@ -0,0 +1,48 @@
# 背景

此示例用于在vLLM中演示chunked parallel pipeline功能,通过mlu_hijack机制将需要修改的代码劫持到当前目录,避免修改主仓库代码。

# 支持模型

- LlamaForCausalLM
- CustomForCausalLM

# Demo运行方式

当前Chunked Parallel Pipeline仅支持通过AsyncLLMEngine方式用paged mode运行。

- 设置环境变量

```bash
export CHUNKED_PIPELINE_PARALLEL_EN=true
```

- 启动server进程

```bash
# 设置engine超时阈值。
export VLLM_ENGINE_ITERATION_TIMEOUT_S=180

python -m vllm.entrypoints.openai.api_server \
    --port ${PORT} \
    --model ${MODEL_PATH} \
    --swap-space 16 \
    --pipeline-parallel-size ${PP_SIZE} \
    --max-num-batched-tokens ${MAX_TOKENS_NUM} \
    --enable-chunked-prefill \
    --worker-use-ray \
    --enforce-eager
```

- 启动client进程

这里以随机数为例,可以选用真实数据集。

```bash
python benchmarks/benchmark_serving.py \
    --backend vllm \
    --model ${MODEL_PATH} \
    --dataset-name random \
    --num-prompts ${NUM_PROMPT} \
    --port ${PORT} \
    --random-input-len ${INPUT_LEN} \
    --random-output-len 1 \
    --request-rate inf
```
||||
@@ -0,0 +1 @@
|
||||
from . import parallel_state
|
||||
@@ -0,0 +1,223 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
"""vLLM distributed state.
|
||||
It takes over the control of the distributed environment from PyTorch.
|
||||
The typical workflow is:
|
||||
|
||||
- call `init_distributed_environment` to initialize the distributed environment.
|
||||
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
|
||||
initialize the model parallel groups.
|
||||
|
||||
- any code dealing with the distributed stuff
|
||||
|
||||
- call `destroy_model_parallel` to destroy the model parallel groups.
|
||||
- call `destroy_distributed_environment` to destroy the distributed environment.
|
||||
|
||||
If you only need to use the distributed environment without model/pipeline
|
||||
parallelism, you can skip the model parallel initialization and destruction
|
||||
steps.
|
||||
"""
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from vllm.distributed.parallel_state import (
|
||||
GroupCoordinator,
|
||||
_split_tensor_dict,
|
||||
TensorMetadata,
|
||||
)
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__distributed__GroupCoordinator__send_tensor_dict(
    self,
    tensor_dict: Dict[str, Union[torch.Tensor, Any]],
    dst: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
    """Send the input tensor dictionary to another pipeline rank.

    vllm_mlu hijack: unlike the stock vLLM implementation, the tensor
    metadata list is NOT transmitted (the receiver is expected to obtain it
    out of band), and device tensors use non-blocking ``isend``.

    NOTE: `dst` is the local (in-group) rank of the destination, as used in
    the ``torch.distributed.send(dst=...)`` calls below.
    Returns the dict unchanged on single-rank runs, otherwise None.
    """
    # Nothing to communicate on a single rank / uninitialized backend.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict

    gather_world = 1 if all_gather_group is None else all_gather_group.world_size
    gather_rank = 0 if all_gather_group is None else all_gather_group.rank_in_group

    device_pg = self.device_group
    cpu_pg = self.cpu_group

    # Default destination: the next rank in the pipeline ring.
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    # vllm_mlu hijack: the metadata list is deliberately not sent here.
    assert isinstance(
        tensor_dict,
        dict), f"Expecting a dictionary, got {type(tensor_dict)}"
    _, payload = _split_tensor_dict(tensor_dict)

    for t in payload:
        # Empty tensors carry no data; skip them.
        if t.numel() == 0:
            continue

        # send-allgather optimization: ship only this rank's slice and let
        # the receiver reassemble the full tensor via all-gather.
        if all_gather_group is not None and t.numel() % gather_world == 0:
            t = t.reshape(gather_world, -1)[gather_rank]

        if t.is_cpu:
            # CPU tensors travel over the CPU (metadata) process group.
            torch.distributed.send(t, dst=self.ranks[dst], group=cpu_pg)
        else:
            # vllm_mlu hijack: non-blocking isend for device tensors.
            torch.distributed.isend(t, dst=self.ranks[dst], group=device_pg)

    return None
|
||||
|
||||
"""
|
||||
=============================
|
||||
Modifies by vllm_mlu
|
||||
=============================
|
||||
@brief: Add a parameter `recv_metadata_list`.
|
||||
"""
|
||||
def vllm__distributed__GroupCoordinator__recv_tensor_dict(
|
||||
self,
|
||||
src: Optional[int] = None,
|
||||
all_gather_group: Optional["GroupCoordinator"] = None,
|
||||
recv_metadata_list: List[Tuple[str, Any]] = [],
|
||||
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
|
||||
"""
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
"""
|
||||
"""Recv the input tensor dictionary.
|
||||
NOTE: `src` is the local rank of the source rank.
|
||||
"""
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if not torch.distributed.is_initialized() or self.world_size == 1:
|
||||
return None
|
||||
|
||||
all_gather_size = (1 if all_gather_group is None else
|
||||
all_gather_group.world_size)
|
||||
all_gather_rank = (0 if all_gather_group is None else
|
||||
all_gather_group.rank_in_group)
|
||||
|
||||
group = self.device_group
|
||||
metadata_group = self.cpu_group
|
||||
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
"""
|
||||
=============================
|
||||
Modifies by vllm_mlu
|
||||
=============================
|
||||
@brief: Skip receiving tensor metadata list.
|
||||
"""
|
||||
"""
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
"""
|
||||
tensor_dict: Dict[str, Any] = {}
|
||||
for key, value in recv_metadata_list:
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = torch.empty(value.size,
|
||||
dtype=value.dtype,
|
||||
device=value.device)
|
||||
if tensor.numel() == 0:
|
||||
# Skip broadcasting empty tensors.
|
||||
tensor_dict[key] = tensor
|
||||
continue
|
||||
|
||||
# send-allgather: send only a slice, then do allgather.
|
||||
use_all_gather = (all_gather_group is not None
|
||||
and tensor.numel() % all_gather_size == 0)
|
||||
|
||||
if use_all_gather:
|
||||
orig_shape = tensor.shape
|
||||
tensor = tensor.reshape(all_gather_size,
|
||||
-1)[all_gather_rank]
|
||||
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
torch.distributed.recv(tensor,
|
||||
src=self.ranks[src],
|
||||
group=metadata_group)
|
||||
else:
|
||||
"""
|
||||
=============================
|
||||
Modifies by vllm_mlu
|
||||
=============================
|
||||
@brief: Modify recv to irecv, and wait to finish.
|
||||
"""
|
||||
# use group for GPU tensors
|
||||
req = torch.distributed.irecv(tensor,
|
||||
src=self.ranks[src],
|
||||
group=group)
|
||||
req.wait()
|
||||
"""
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
"""
|
||||
if use_all_gather:
|
||||
# do the allgather
|
||||
tensor = all_gather_group.all_gather( # type: ignore
|
||||
tensor, dim=0)
|
||||
tensor = tensor.reshape(orig_shape)
|
||||
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
tensor_dict[key] = value
|
||||
return tensor_dict
|
||||
|
||||
# Patch vLLM's GroupCoordinator so that tensor-dict send/recv use the
# metadata-free, asynchronous MLU variants defined above.
MluHijackObject.apply_hijack(
    GroupCoordinator,
    GroupCoordinator.send_tensor_dict,
    vllm__distributed__GroupCoordinator__send_tensor_dict,
)
MluHijackObject.apply_hijack(
    GroupCoordinator,
    GroupCoordinator.recv_tensor_dict,
    vllm__distributed__GroupCoordinator__recv_tensor_dict,
)
|
||||
@@ -0,0 +1 @@
|
||||
from . import async_llm_engine
|
||||
@@ -0,0 +1,310 @@
|
||||
import asyncio
|
||||
from typing import (List, Optional, Union)
|
||||
|
||||
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S as ENGINE_ITERATION_TIMEOUT_S
|
||||
from vllm.core.scheduler import ScheduledSequenceGroup
|
||||
from vllm.engine.async_timeout import asyncio_timeout
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroup, SequenceGroupMetadata
|
||||
from vllm.engine.async_llm_engine import (_AsyncLLMEngine, AsyncLLMEngine)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__engine__async_llm_engine___AsyncLLMEngine____init__(self, *args, **kwargs):
    """Hijacked _AsyncLLMEngine constructor.

    Runs the stock LLMEngine initialisation, then adds the bookkeeping
    needed by chunked pipeline parallel: one dict per virtual engine,
    mapping request_id -> list of in-flight chunked-prefill asyncio tasks.
    """
    LLMEngine.__init__(self, *args, **kwargs)

    # vllm_mlu hijack: per-virtual-engine {request_id: [task, ...]} records.
    self.step_tasks = [{} for _ in self.scheduler]
|
||||
|
||||
def _update_scheduler_status(
    self,
    scheduled_seq_groups: List[ScheduledSequenceGroup],
    ignored_seq_groups: List[SequenceGroup],
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> None:
    """Advance scheduler bookkeeping after a prefill chunk is emitted.

    For chunked pipeline parallel the prefill chunks run asynchronously, so
    computed-token counters are bumped as soon as the tasks are emitted
    rather than when their results come back.
    (`ignored_seq_groups` is currently unused; kept for interface parity.)
    """
    # Credit each scheduled group with the chunk it was just dispatched.
    for sched_group, _meta in zip(scheduled_seq_groups,
                                  seq_group_metadata_list):
        sched_group.seq_group.update_num_computed_tokens(
            sched_group.token_chunk_size)

    # Reap any sequence groups that finished in the meantime.
    for sched in self.scheduler:
        sched.free_finished_seq_groups()
|
||||
|
||||
async def vllm__engine__async_llm_engine___AsyncLLMEngine__step_async(
    self, virtual_engine: int
) -> Optional[List[Union[RequestOutput, EmbeddingRequestOutput]]]:
    """Performs one decoding iteration and returns newly generated results.
    The workers are run asynchronously if possible.

    This function performs one decoding iteration of the engine. It first
    schedules the sequences to be executed in the next iteration and the
    token blocks to be swapped in/out/copy. Then, it executes the model
    and updates the scheduler with the model outputs. Finally, it decodes
    the sequences and returns the newly generated results.

    vllm_mlu hijack: during prefill, each chunk is dispatched as a detached
    asyncio task; only the final chunk of a request awaits all previously
    dispatched chunks and yields output. Intermediate chunks return None.
    """
    # Fix: SchedulerOutputState is used below (multi-step path) but was not
    # imported at module level, causing a NameError when that path runs.
    from vllm.engine.llm_engine import SchedulerOutputState

    # these are cached outputs from previous iterations. None if on first
    # iteration
    cached_outputs = self.cached_scheduler_outputs[virtual_engine]
    seq_group_metadata_list = cached_outputs.seq_group_metadata_list
    scheduler_outputs = cached_outputs.scheduler_outputs
    allow_async_output_proc = cached_outputs.allow_async_output_proc

    ctx = self.scheduler_contexts[virtual_engine]

    # Clear outputs for each new scheduler iteration
    ctx.request_outputs.clear()

    # skip the scheduler if there are any remaining steps in the seq groups.
    # This ensures that the scheduler is only called again when the current
    # batch has completed.
    if not self._has_remaining_steps(seq_group_metadata_list):

        # Schedule iteration
        (seq_group_metadata_list, scheduler_outputs,
         allow_async_output_proc
         ) = self.scheduler[virtual_engine].schedule()

        ctx.seq_group_metadata_list = seq_group_metadata_list
        ctx.scheduler_outputs = scheduler_outputs

        # Maybe switch from async mode to sync mode
        if not allow_async_output_proc and len(ctx.output_queue) > 0:
            self._process_model_outputs(ctx=ctx)

        if (self.scheduler_config.is_multi_step
                and scheduler_outputs.num_lookahead_slots > 0):
            # cache the scheduler outputs for the next iteration if we have
            # lookahead slots
            self._cache_scheduler_outputs_for_multi_step(
                virtual_engine, seq_group_metadata_list, scheduler_outputs,
                allow_async_output_proc)

    assert seq_group_metadata_list is not None
    assert scheduler_outputs is not None

    if not scheduler_outputs.is_empty():
        finished_requests_ids = self.scheduler[
            virtual_engine].get_and_reset_finished_requests_ids()

        # Check if we have a cached last_output from the previous iteration.
        # For supporting PP this is probably the best way to pass the
        # sampled_token_ids, as a separate broadcast over all the PP stages
        # will cause one virtual engine's microbatch to block the pipeline.
        last_sampled_token_ids = \
            self._get_last_sampled_token_ids(virtual_engine)

        execute_model_req = ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
            virtual_engine=virtual_engine,
            num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
            running_queue_size=scheduler_outputs.running_queue_size,
            finished_requests_ids=finished_requests_ids,
            # We use ExecuteModelRequest to pass the last sampled_token_ids
            # to each of the non-last PP stages for in-place prepare_input.
            last_sampled_token_ids=last_sampled_token_ids)

        if allow_async_output_proc:
            execute_model_req.async_callback = self.async_callbacks[
                virtual_engine]

        # Execute the model.
        """
        =============================
        Modifies by vllm_mlu
        =============================
        @brief: for chunked prefill tasks except the final task for a single
        request, create them asynchronously. And for the last prefill task,
        gather all previous tasks and get the final output.
        """
        if seq_group_metadata_list[0].is_prompt:
            assert len(seq_group_metadata_list) == 1, \
                "Currently we only support schedule single batch in " \
                "prefill stage for chunked pipeline parallel."
            token_chunk_size = seq_group_metadata_list[0].token_chunk_size
            seq_data = list(seq_group_metadata_list[0].seq_data.values())[0]
            prefill_loc = seq_data.get_num_computed_tokens()
            # Dispatch this prefill chunk without awaiting it.
            task = asyncio.create_task(
                self.model_executor.execute_model_async(
                    execute_model_req, [prefill_loc], [token_chunk_size]))
            request_id = seq_group_metadata_list[0].request_id
            self.step_tasks[virtual_engine].setdefault(request_id,
                                                       []).append(task)

            # Gather point: if all prefill tasks for current sequence group
            # have been dispatched, we wait all prompt tasks and get the
            # final output.
            seq_len = seq_data.get_len()
            if token_chunk_size + prefill_loc == seq_len:
                outputs = await asyncio.gather(
                    *self.step_tasks[virtual_engine][request_id])
                # Only the last chunk's result carries the sampled output.
                outputs = outputs[-1]
            else:
                # Since prefill stage has not been completely finished, we
                # just update scheduler and sequence status and return None.
                _update_scheduler_status(
                    self, scheduler_outputs.scheduled_seq_groups,
                    scheduler_outputs.ignored_seq_groups,
                    seq_group_metadata_list)
                return None
        else:
            """
            =============================
            End of MLU Hijack
            =============================
            """
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

        # we need to do this here so that last step's sampled_token_ids can
        # be passed to the next iteration for PP.
        if self.scheduler_config.is_multi_step:
            self._update_cached_scheduler_output(virtual_engine, outputs)
    else:
        if len(ctx.output_queue) > 0:
            self._process_model_outputs(ctx=ctx)
        outputs = []

    # Finish the current step for all the sequence groups.
    if self.scheduler_config.is_multi_step:
        for seq_group in seq_group_metadata_list:
            seq_group.finish_step()

    if not self._has_remaining_steps(seq_group_metadata_list):
        # Clear the cache if we have finished all the steps
        if self.scheduler_config.is_multi_step:
            self.cached_scheduler_outputs[
                virtual_engine] = SchedulerOutputState()

        # is_first_step_output is True only when the num_steps of all
        # the sequences are 1. When the num_steps > 1,
        # multi_step_model_runner does the first-step output append.
        is_first_step_output: bool = False if not seq_group_metadata_list \
            else seq_group_metadata_list[0].state.num_steps == 1

        ctx.append_output(outputs=outputs,
                          seq_group_metadata_list=seq_group_metadata_list,
                          scheduler_outputs=scheduler_outputs,
                          is_async=allow_async_output_proc,
                          is_last_step=True,
                          is_first_step_output=is_first_step_output)

        if outputs and allow_async_output_proc:
            assert len(
                outputs
            ) == 1, "Async postprocessor expects only a single output set"
            self._advance_to_next_step(
                outputs[0], seq_group_metadata_list,
                scheduler_outputs.scheduled_seq_groups)

        if not allow_async_output_proc:
            self._process_model_outputs(ctx=ctx)

            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs)

            # Tracing
            self.do_tracing(scheduler_outputs)

    else:
        # Multi-step case
        return ctx.request_outputs

    if not self.has_unfinished_requests():
        # Drain async postprocessor (if exists)
        if len(ctx.output_queue) > 0:
            self._process_model_outputs(ctx=ctx)
        assert len(ctx.output_queue) == 0

    return ctx.request_outputs
|
||||
|
||||
async def vllm__engine__async_llm_engine__AsyncLLMEngine__engine_step(
    self, virtual_engine: int
) -> bool:
    """Kick the engine to process the waiting requests.

    Returns True if there are in-progress requests."""

    pending, aborted = self._request_tracker.get_new_and_aborted_requests()

    # Feed every newly arrived request into the engine's waiting queue.
    for req in pending:
        try:
            await self.engine.add_request_async(**req)
        except ValueError as err:
            # TODO: use a vLLM specific error for failed validation
            self._request_tracker.process_exception(
                req["request_id"],
                err,
                verbose=self.log_requests,
            )

    if aborted:
        await self._engine_abort(aborted)

    step_outputs = await self.engine.step_async(virtual_engine)

    # Chunked Parallel Pipeline hijack: a None result signals that prefill
    # chunks are still in flight, so the engine must keep running.
    if step_outputs is None:
        return True

    # Put the outputs into the corresponding streams. If used as a callback,
    # this was already invoked inside LLMEngine's _process_model_outputs.
    if self.use_process_request_outputs_callback:
        # Callback case: only detect whether every request is finished.
        finished = all(out.finished for out in step_outputs)
    else:
        finished = self.process_request_outputs(step_outputs)

    return not finished
|
||||
|
||||
# Install the chunked-pipeline hijacks on vLLM's async engine classes:
# __init__ adds the step-task bookkeeping, step_async performs chunked
# prefill dispatch/gather, and engine_step tolerates in-flight chunks.
MluHijackObject.apply_hijack(
    _AsyncLLMEngine,
    _AsyncLLMEngine.__init__,
    vllm__engine__async_llm_engine___AsyncLLMEngine____init__
)
MluHijackObject.apply_hijack(
    _AsyncLLMEngine,
    _AsyncLLMEngine.step_async,
    vllm__engine__async_llm_engine___AsyncLLMEngine__step_async
)
MluHijackObject.apply_hijack(
    AsyncLLMEngine,
    AsyncLLMEngine.engine_step,
    vllm__engine__async_llm_engine__AsyncLLMEngine__engine_step
)
|
||||
@@ -0,0 +1,3 @@
|
||||
from . import distributed_gpu_executor
|
||||
from . import distributed_mlu_executor
|
||||
from . import ray_mlu_executor
|
||||
@@ -0,0 +1,75 @@
|
||||
import asyncio
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.executor.distributed_gpu_executor import DistributedGPUExecutorAsync
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add two parameters, in which prefill_locs indicates the start location
and token_chunk_sizes indicates the chunk size for each task.
'''
async def vllm__executor__distributed_gpu_executor__DistributedGPUExecutorAsync__execute_model_async(
    self,
    execute_model_req: ExecuteModelRequest,
    prefill_locs: Optional[List[int]] = None,
    token_chunk_sizes: Optional[List[int]] = None,
) -> List[SamplerOutput]:
    """Run one model step, forwarding chunked-prefill bounds to the driver.

    `prefill_locs` / `token_chunk_sizes` give, per task, the start offset
    and the token count of the prefill chunk (None outside prefill).
    """
    # Lazily start the execution loop on the non-driver workers, once.
    if self.parallel_worker_tasks is None:
        loop_task = asyncio.create_task(self._start_worker_execution_loop())
        self.parallel_worker_tasks = loop_task

    # Only the driver worker returns the sampling results.
    result = await self._driver_execute_model_async(
        execute_model_req, prefill_locs, token_chunk_sizes)
    return result
|
||||
|
||||
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add two parameters, in which prefill_locs indicates the start location
and token_chunk_sizes indicates the chunk size for each task.
'''
# Abstract hook: concrete executors (e.g. the Ray-based one) must override
# this with an implementation that accepts the chunked-prefill bounds.
@abstractmethod
async def vllm__executor__distributed_gpu_executor__DistributedGPUExecutorAsync___driver_execute_model_async(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None,
    prefill_locs: Optional[List[int]] = None,
    token_chunk_sizes: Optional[List[int]] = None,
) -> List[SamplerOutput]:
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    """Execute the model asynchronously in the driver worker.

    Passing None will cause the driver to stop the model execution
    loop running in each of the remote workers.
    """
    raise NotImplementedError
|
||||
|
||||
# Patch DistributedGPUExecutorAsync so both the public entry point and the
# driver hook carry the extra chunked-prefill parameters.
MluHijackObject.apply_hijack(
    DistributedGPUExecutorAsync,
    DistributedGPUExecutorAsync.execute_model_async,
    vllm__executor__distributed_gpu_executor__DistributedGPUExecutorAsync__execute_model_async
)
MluHijackObject.apply_hijack(
    DistributedGPUExecutorAsync,
    DistributedGPUExecutorAsync._driver_execute_model_async,
    vllm__executor__distributed_gpu_executor__DistributedGPUExecutorAsync___driver_execute_model_async
)
|
||||
@@ -0,0 +1,75 @@
|
||||
import asyncio
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.executor.distributed_mlu_executor import DistributedMLUExecutorAsync
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add two parameters, in which prefill_locs indicates the start location
and token_chunk_sizes indicates the chunk size for each task.
'''
async def vllm__executor__distributed_mlu_executor__DistributedMLUExecutorAsync__execute_model_async(
    self,
    execute_model_req: ExecuteModelRequest,
    prefill_locs: Optional[List[int]] = None,
    token_chunk_sizes: Optional[List[int]] = None,
) -> List[SamplerOutput]:
    """Run one model step, forwarding chunked-prefill bounds to the driver.

    `prefill_locs` / `token_chunk_sizes` give, per task, the start offset
    and the token count of the prefill chunk (None outside prefill).
    """
    # Lazily start the execution loop on the non-driver workers, once.
    if self.parallel_worker_tasks is None:
        loop_task = asyncio.create_task(self._start_worker_execution_loop())
        self.parallel_worker_tasks = loop_task

    # Only the driver worker returns the sampling results.
    result = await self._driver_execute_model_async(
        execute_model_req, prefill_locs, token_chunk_sizes)
    return result
|
||||
|
||||
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add two parameters, in which prefill_locs indicates the start location
and token_chunk_sizes indicates the chunk size for each task.
'''
# Abstract hook: concrete executors (e.g. the Ray-based one) must override
# this with an implementation that accepts the chunked-prefill bounds.
@abstractmethod
async def vllm__executor__distributed_mlu_executor__DistributedMLUExecutorAsync___driver_execute_model_async(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None,
    prefill_locs: Optional[List[int]] = None,
    token_chunk_sizes: Optional[List[int]] = None,
) -> List[SamplerOutput]:
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    """Execute the model asynchronously in the driver worker.

    Passing None will cause the driver to stop the model execution
    loop running in each of the remote workers.
    """
    raise NotImplementedError
|
||||
|
||||
# Patch DistributedMLUExecutorAsync so both the public entry point and the
# driver hook carry the extra chunked-prefill parameters.
MluHijackObject.apply_hijack(
    DistributedMLUExecutorAsync,
    DistributedMLUExecutorAsync.execute_model_async,
    vllm__executor__distributed_mlu_executor__DistributedMLUExecutorAsync__execute_model_async
)
MluHijackObject.apply_hijack(
    DistributedMLUExecutorAsync,
    DistributedMLUExecutorAsync._driver_execute_model_async,
    vllm__executor__distributed_mlu_executor__DistributedMLUExecutorAsync___driver_execute_model_async
)
|
||||
@@ -0,0 +1,175 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.executor.distributed_mlu_executor import DistributedMLUExecutorAsync
|
||||
from vllm.executor.ray_mlu_executor import RayMLUExecutorAsync
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
from ..lock_utils import (_run_task_with_priority_lock, PriorityLock)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Keep a handle on the stock constructor so the hijack can chain to it.
vllm__executor__ray_mlu_executor__RayMLUExecutorAsync____init____org = RayMLUExecutorAsync.__init__


def vllm__executor__ray_mlu_executor__RayMLUExecutorAsync____init__(self, *args, **kwargs):
    """Hijacked RayMLUExecutorAsync constructor (Chunked Parallel Pipeline).

    After the stock initialisation this adds the state used to keep prefill
    chunks on the same pp rank in dispatch order:

    * ``self.priority`` maps request_id -> base priority. Each request gets
      a priority interval of length ``max_model_len`` (no less than its
      model execution rounds), so execution round ``t`` of a request with
      base ``b`` runs at priority ``b + t`` under the priority lock.
    * ``self.lock`` guards atomic creation of the pp tasks for one chunk.
    """
    vllm__executor__ray_mlu_executor__RayMLUExecutorAsync____init____org(
        self, *args, **kwargs)

    self.priority = {}
    self.priority_interval = self.model_config.max_model_len
    self.lock = asyncio.Lock()
|
||||
|
||||
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add two parameters, in which prefill_locs indicates the start location
and token_chunk_sizes indicates the chunk size for each task.
'''
async def vllm__executor__ray_mlu_executor__RayMLUExecutorAsync__execute_model_async(
    self,
    execute_model_req: ExecuteModelRequest,
    prefill_locs: Optional[List[int]] = None,
    token_chunk_sizes: Optional[List[int]] = None,
) -> List[SamplerOutput]:
    """Ray-executor entry point; delegates to the distributed base class."""
    # SPMD mode has no dedicated driver worker, which this path relies on.
    assert not self.use_ray_spmd_worker, (
        "RayMLUExecutorAsync is not supported for spmd mode.")
    result = await DistributedMLUExecutorAsync.execute_model_async(
        self, execute_model_req, prefill_locs, token_chunk_sizes)
    return result
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add two parameters, in which prefill_locs indicates the start location
|
||||
and token_chunk_sizes indicates the chunk size for each task.
|
||||
'''
|
||||
async def vllm__executor__ray_mlu_executor__RayMLUExecutorAsync___driver_execute_model_async(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None,
    prefill_locs: Optional[List[int]] = None,
    token_chunk_sizes: Optional[List[int]] = None,
) -> List[SamplerOutput]:
    '''
    ==================
    End of MLU Hijack
    ==================

    Hijacked ``RayMLUExecutorAsync._driver_execute_model_async``.

    Dispatches one execution round to the driver worker of every pipeline
    stage. Unlike the stock implementation, each per-stage task is wrapped
    in a PriorityLock so that rounds belonging to the same pp rank execute
    in dispatch order. A round's priority is
    ``request_slot * max_model_len + rounds_already_executed``
    (tracked in ``self.priority``).

    Args:
        execute_model_req: request to run; ``None`` is a control round that
            is forwarded with priority -1 (runs before everything else).
        prefill_locs: prefill start offsets, one per task.
        token_chunk_sizes: prefill chunk sizes, one per task.

    Returns:
        Sampler outputs from the last pipeline stage.
    '''
    assert not self.use_ray_spmd_worker, (
        "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
    if not self.tp_driver_workers:
        # Single pipeline stage: no cross-stage ordering problem, call
        # straight through to the driver worker.
        return await self.driver_exec_method(
            "execute_model", execute_model_req, prefill_locs, token_chunk_sizes)
    """
    ======================================
    Modified by Chunked Parallel Pipeline.
    ======================================
    Use PriorityLock instead of lock to ensure that tasks in the same pp rank
    are executed with the dispatched order.
    """
    request_id = 'dummy'
    update_priority_threshold = False
    is_prompt = False
    if execute_model_req is not None:
        assert len(execute_model_req.seq_group_metadata_list) == 1, \
            "Only single batch is supported for chunked pipeline parallel mode."
        request_id = execute_model_req.seq_group_metadata_list[0].request_id
        seq_group_metadata = execute_model_req.seq_group_metadata_list[0]
        # First round of a request reserves the next priority interval
        # (slot index * max_model_len); later rounds reuse the stored value.
        request_priority = self.priority.setdefault(
            request_id, len(self.priority)*self.model_config.max_model_len)
        seq_data = list(seq_group_metadata.seq_data.values())[0]
        seq_len = seq_data.get_len()

        # Update priority threshold to schedule next request.
        # When this is the final prefill chunk of the prompt (the chunk ends
        # exactly at the full sequence length), advance the lock's admission
        # threshold so the next request's rounds may enter.
        # NOTE(review): assumes prefill_locs/token_chunk_sizes are non-None
        # whenever a prompt task reaches here -- confirm against the caller.
        is_prompt = seq_group_metadata.is_prompt
        if is_prompt and seq_len == prefill_locs[0] + token_chunk_sizes[0]:
            update_priority_threshold = True
    else:
        # Control round: -1 sorts ahead of every real request priority.
        request_priority = -1

    if self.pp_locks is None:
        # This locks each pipeline parallel stage so multiple virtual
        # engines can't execute on the same stage at the same time.
        # We create the locks here to avoid creating them in the constructor,
        # which uses a different asyncio loop.
        self.pp_locks = [
            PriorityLock(init_priority_threshold=self.model_config.max_model_len,
                         priority_interval=self.priority_interval)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

    # self.lock makes "create one task per pp stage + bump the request's
    # priority" atomic, so rounds of different requests cannot interleave
    # their enqueue order across stages.
    async with self.lock:
        tasks = [
            asyncio.create_task(
                _run_task_with_priority_lock(
                    self.driver_exec_method, self.pp_locks[0], request_priority,
                    update_priority_threshold,
                    # NOTE(review): request_priority is also forwarded as an
                    # extra positional argument to the worker's execute_model;
                    # presumably consumed by the hijacked worker -- verify.
                    "execute_model", execute_model_req, prefill_locs, token_chunk_sizes,
                    request_priority))
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_priority_lock(
                        driver_worker.execute_method.remote,
                        self.pp_locks[pp_rank], request_priority,
                        update_priority_threshold,
                        "execute_model", execute_model_req, prefill_locs, token_chunk_sizes,
                        request_priority)))
        if execute_model_req is not None:
            # The next round of this request runs later (numerically higher
            # priority): advance by the chunk size for prefill, 1 for decode.
            self.priority[request_id] += (token_chunk_sizes[0] if is_prompt else 1)

    """
    ======================================
    End by Chunked Parallel Pipeline.
    ======================================
    """

    results = await asyncio.gather(*tasks)

    # Only the last PP stage has the final results.
    return results[-1]
|
||||
|
||||
# Install the chunked-pipeline hijacks on RayMLUExecutorAsync. Each call
# replaces an existing method on the class with the corresponding
# vllm__executor__... function defined in this module, so every instance
# (current and future) uses the chunked parallel pipeline behavior.
MluHijackObject.apply_hijack(
    RayMLUExecutorAsync,
    RayMLUExecutorAsync.__init__,
    vllm__executor__ray_mlu_executor__RayMLUExecutorAsync____init__
)
# execute_model_async: threads prefill_locs / token_chunk_sizes through.
MluHijackObject.apply_hijack(
    RayMLUExecutorAsync,
    RayMLUExecutorAsync.execute_model_async,
    vllm__executor__ray_mlu_executor__RayMLUExecutorAsync__execute_model_async
)
# _driver_execute_model_async: priority-lock-ordered per-stage dispatch.
MluHijackObject.apply_hijack(
    RayMLUExecutorAsync,
    RayMLUExecutorAsync._driver_execute_model_async,
    vllm__executor__ray_mlu_executor__RayMLUExecutorAsync___driver_execute_model_async
)
|
||||
@@ -0,0 +1,218 @@
|
||||
import asyncio
|
||||
from typing import Callable
|
||||
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PriorityLock:
    """
    A lock class that prioritizes tasks based on their priority level and
    supports dynamic updating of priority thresholds after each lock release.

    Lower numerical priority values win. A waiter may only acquire the lock
    while its priority is strictly below the current admission threshold.

    NOTE(review): the non-standard ``__aenter__(priority)`` /
    ``__aexit__(..., update_priority_threshold)`` signatures mean this class
    cannot be used directly in an ``async with`` statement; it is driven
    explicitly by PriorityLockManager below.

    Attributes:
    -----------
    _lock : asyncio.Lock
        An internal asyncio lock used to ensure mutual exclusion.
    _queue : asyncio.PriorityQueue
        A priority queue to store tasks by their priority. Tasks with lower
        numerical priority values have higher priority.
    _condition : asyncio.Condition
        A condition variable to manage the waiting and notification of tasks.
    _active_task : asyncio.Task or None
        Tracks the task currently holding the lock, or None if the lock is
        not held. (Tracked for debugging only -- never read for enforcement.)
    _current_priority_threshold : int
        The current priority threshold for tasks allowed to acquire the lock.
    _priority_interval : int
        The value by which the priority threshold is incremented after a lock
        release when `update_priority_threshold` is enabled.
    """

    def __init__(self, init_priority_threshold: int, priority_interval: int):
        """
        Initializes a PriorityLock with an initial priority threshold and
        interval.

        Parameters:
        -----------
        init_priority_threshold : int
            The initial threshold for task priorities that can acquire the lock.
        priority_interval : int
            The interval by which the priority threshold increases after each
            lock release.
        """
        self._lock = asyncio.Lock()  # Internal asyncio lock
        self._queue = asyncio.PriorityQueue()  # Priority queue to manage tasks by priority
        self._condition = asyncio.Condition()  # Condition variable to manage waiting tasks
        self._active_task = None  # Keep track of the current active task holding the lock
        self._current_priority_threshold = init_priority_threshold
        self._priority_interval = priority_interval

    async def acquire(self, priority):
        """
        Acquires the lock for a task based on its priority.

        Parameters:
        -----------
        priority : int
            The priority level of the task attempting to acquire the lock.

        Behavior:
        ---------
        - The task is enqueued based on its priority.
        - The task waits until it is the highest-priority task in the queue,
          has a priority below the current threshold, and the lock is
          available.

        NOTE(review): queue entries are ``(priority, Task)`` tuples; if two
        waiters enqueue with equal priorities, heap comparison falls through
        to the Task objects, which define no ordering -- presumably
        priorities are unique per waiter; confirm with the caller.
        """
        queue_item = (priority, asyncio.current_task())
        async with self._condition:
            await self._queue.put(queue_item)

            # Wait until the current task is the one with the highest priority and the lock is available
            while True:
                # Check if the current task is at the front of the queue and the lock is available.
                # Peeks into asyncio.PriorityQueue's private heap list; safe
                # here because our own entry guarantees the queue is non-empty,
                # but it does rely on a CPython implementation detail.
                current_priority, current_task = self._queue._queue[0]  # Peek at the highest priority task
                if current_priority < self._current_priority_threshold and current_task is asyncio.current_task() and not self._lock.locked():
                    await self._lock.acquire()  # Acquire the lock
                    self._active_task = current_task  # Mark the current task as holding the lock
                    await self._queue.get()  # Remove the task from the queue
                    break
                # If not the highest priority task, wait until notified
                await self._condition.wait()

    async def release(self, update_priority_threshold):
        """
        Releases the lock, optionally updating the priority threshold.

        Parameters:
        -----------
        update_priority_threshold : bool
            If True, increments the priority threshold by the configured
            interval.

        Note: assumes the caller currently holds the lock;
        ``asyncio.Lock.release`` raises RuntimeError otherwise.
        """
        # Notify waiting tasks that the lock has been released
        async with self._condition:
            self._active_task = None  # Clear the reference to the current task
            self._lock.release()

            if update_priority_threshold:
                self._current_priority_threshold += self._priority_interval
            self._condition.notify_all()  # Wake up all waiting tasks to recheck their priority

    async def __aenter__(self, priority):
        """
        Async context manager entry. Acquires the lock with the specified
        priority.

        NOTE(review): ``async with`` invokes ``__aenter__()`` with no
        arguments, so this method is only callable explicitly (see
        PriorityLockManager).

        Parameters:
        -----------
        priority : int
            The priority level of the task acquiring the lock.

        Returns:
        --------
        self : PriorityLock
            The lock instance.
        """
        await self.acquire(priority)
        return self

    async def __aexit__(self, exc_type, exc, tb, update_priority_threshold):
        """
        Async context manager exit. Releases the lock and optionally updates
        the priority threshold.

        NOTE(review): extra ``update_priority_threshold`` parameter makes this
        incompatible with a plain ``async with``; PriorityLockManager calls it
        explicitly with all four arguments.

        Parameters:
        -----------
        exc_type : Exception or None
            The exception type, if any, raised in the 'async with' block.
        exc : Exception or None
            The exception instance, if any, raised in the 'async with' block.
        tb : traceback or None
            The traceback object, if any, associated with the exception.
        update_priority_threshold : bool
            If True, increments the priority threshold after releasing the lock.
        """
        await self.release(update_priority_threshold)  # Now release is async
|
||||
|
||||
|
||||
class PriorityLockManager:
    """Async context manager adapter for :class:`PriorityLock`.

    ``PriorityLock`` cannot be used directly in an ``async with`` statement
    because its ``__aenter__``/``__aexit__`` take extra arguments. This small
    adapter captures the per-task ``priority`` and the
    ``update_priority_threshold`` flag up front and forwards them at
    enter/exit time.
    """

    def __init__(self, lock, priority, update_priority_threshold):
        """Capture the lock and the per-task parameters.

        Args:
            lock: the PriorityLock instance to manage.
            priority: priority used when acquiring the lock.
            update_priority_threshold: whether the lock's admission threshold
                is advanced when the lock is released.
        """
        self._plock = lock
        self._prio = priority
        self._bump_threshold = update_priority_threshold

    async def __aenter__(self):
        """Acquire the managed lock at the captured priority.

        Returns:
            The underlying PriorityLock instance (not this manager).
        """
        await self._plock.acquire(self._prio)
        return self._plock

    async def __aexit__(self, exc_type, exc, tb):
        """Release the managed lock, forwarding the threshold-update flag."""
        await self._plock.__aexit__(exc_type, exc, tb, self._bump_threshold)
|
||||
|
||||
|
||||
async def _run_task_with_priority_lock(
        task: Callable, lock: "PriorityLock", priority: int,
        update_priority_threshold: bool, *args, **kwargs):
    """
    Runs a task within the context of a PriorityLock, ensuring proper
    acquisition and release.

    Parameters:
    -----------
    task : Callable
        The async function representing the task to be executed.
    lock : PriorityLock
        The PriorityLock instance managing access. (Fixed: the annotation
        previously said ``asyncio.Lock``, but a plain asyncio.Lock cannot
        work here -- ``acquire`` must accept a priority and ``__aexit__``
        the threshold-update flag, which only PriorityLock provides.)
    priority : int
        The priority level for the task.
    update_priority_threshold : bool
        Whether to update the priority threshold after releasing the lock.
    *args, **kwargs:
        Additional arguments to pass to the task function.

    Returns:
    --------
    result : Any
        The result of the task execution.
    """
    # Acquire the lock at the given priority; it is released (and the
    # admission threshold optionally advanced) when the block exits, even
    # if the task raises.
    async with PriorityLockManager(lock, priority, update_priority_threshold):
        return await task(*args, **kwargs)
|
||||
@@ -0,0 +1,14 @@
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
from . import distributed
|
||||
from . import engine
|
||||
from . import executor
|
||||
from . import model_executor
|
||||
from . import worker
|
||||
|
||||
logger.info("Apply Chunked Pipeline Parallel Demo!")
|
||||
@@ -0,0 +1,2 @@
|
||||
# hijack vllm models
|
||||
from .models import custom, llama
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import TensorMetadata
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.model_executor.custom_model.custom import CustomForCausalLM
|
||||
|
||||
def vllm__module_executor__models__custom_model__CustomForCausalLM__get_intermediate_tensor_metadata(
    self,
    batch_size: int,
    dtype: torch.dtype,
    device: torch.device,
) -> List[Tuple[str, Any]]:
    """Describe the intermediate tensors this model passes between pp stages.

    Returns a ``(name, metadata)`` list: ``hidden_states`` is a
    ``(batch_size, hidden_size)`` tensor described by TensorMetadata, while
    ``residual`` carries no tensor payload (metadata ``None``).
    """
    hidden_shape = torch.Size([batch_size, self.config.hidden_size])
    hidden_meta = TensorMetadata(device.type, dtype, hidden_shape)
    return [
        ("hidden_states", hidden_meta),
        ("residual", None),
    ]
|
||||
|
||||
# Attach the method to CustomForCausalLM so pipeline-parallel stages can
# exchange intermediate tensors. The method name is passed as a string here
# (earlier hijacks pass the bound attribute) -- presumably because the class
# does not define get_intermediate_tensor_metadata yet and the hijack *adds*
# it; confirm against MluHijackObject.apply_hijack's handling of strings.
MluHijackObject.apply_hijack(
    CustomForCausalLM,
    "get_intermediate_tensor_metadata",
    vllm__module_executor__models__custom_model__CustomForCausalLM__get_intermediate_tensor_metadata
)
|
||||
@@ -0,0 +1,24 @@
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import TensorMetadata
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
|
||||
def vllm__module_executor__models__llama__LlamaForCausalLM__get_intermediate_tensor_metadata(
    self,
    batch_size: int,
    dtype: torch.dtype,
    device: torch.device,
) -> List[Tuple[str, Any]]:
    """Describe the intermediate tensors Llama passes between pp stages.

    Returns a single-entry ``(name, metadata)`` list: ``hidden_states`` is a
    ``(batch_size, hidden_size)`` tensor described by TensorMetadata.
    """
    hidden_shape = torch.Size([batch_size, self.config.hidden_size])
    return [
        ("hidden_states", TensorMetadata(device.type, dtype, hidden_shape)),
    ]
|
||||
|
||||
# Attach the method to LlamaForCausalLM so pipeline-parallel stages can
# exchange intermediate tensors. As with the custom model above, the method
# name is passed as a string -- presumably because the hijack *adds* a new
# method rather than replacing an existing one; verify against
# MluHijackObject.apply_hijack.
MluHijackObject.apply_hijack(
    LlamaForCausalLM,
    "get_intermediate_tensor_metadata",
    vllm__module_executor__models__llama__LlamaForCausalLM__get_intermediate_tensor_metadata
)
|
||||
@@ -0,0 +1,3 @@
|
||||
from . import mlu_model_runner
|
||||
from . import model_runner
|
||||
from . import worker_base
|
||||
@@ -0,0 +1,176 @@
|
||||
import weakref
|
||||
from typing import (List, Optional)
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import (
|
||||
TModelInputForGPU,
|
||||
LORA_WARMUP_RANK,
|
||||
_BATCH_SIZES_TO_CAPTURE
|
||||
)
|
||||
from vllm.worker.mlu_model_runner import (
|
||||
MLUModelRunnerBase,
|
||||
ModelInputForMLUBuilder
|
||||
)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@torch.inference_mode()
def vllm__worker__mlu_model_runner__MLUModelRunnerBase__profile_run(self) -> None:
    """Run one worst-case dummy batch to profile peak device memory.

    Hijacked MLUModelRunnerBase.profile_run: builds up to max_num_seqs dummy
    prompt groups totalling max_num_batched_tokens tokens, executes one
    forward pass with empty kv caches, and synchronizes the MLU device so
    the memory profiler can observe the peak. Deviations from stock vLLM are
    marked by the vllm_mlu / Chunked Parallel Pipeline banners below.
    """
    # Enable top-k sampling to reflect the accurate memory usage.
    sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
    max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
    max_num_seqs = self.scheduler_config.max_num_seqs
    # This represents the maximum number of different requests that will
    # have unique loras, and therefore the max amount of memory consumption.
    # Create dummy lora request copies from the lora request passed in,
    # which contains a lora from the lora warmup path.
    dummy_lora_requests: List[LoRARequest] = []
    dummy_lora_requests_per_seq: List[LoRARequest] = []
    if self.lora_config:
        assert self.lora_manager is not None
        with self.lora_manager.dummy_lora_cache():
            for idx in range(self.lora_config.max_loras):
                lora_id = idx + 1
                dummy_lora_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(dummy_lora_request)
            # Round-robin the dummy loras over the profiled sequences.
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % len(dummy_lora_requests)]
                for idx in range(max_num_seqs)
            ]

    # Profile memory usage with max_num_sequences sequences and the total
    # number of tokens equal to max_num_batched_tokens.
    seqs: List[SequenceGroupMetadata] = []
    # Additional GPU memory may be needed for multi-modal encoding, which
    # needs to be accounted for when calculating the GPU blocks for
    # vLLM blocker manager.
    # To exercise the worst scenario for GPU memory consumption,
    # the number of seqs (batch_size) is chosen to maximize the number
    # of images processed.

    max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
        self.model_config)
    if max_mm_tokens > 0:
        # Cap the number of sequences so each can carry a full multimodal
        # payload within the token budget; never go below 1.
        max_num_seqs_orig = max_num_seqs
        max_num_seqs = min(max_num_seqs,
                           max_num_batched_tokens // max_mm_tokens)
        if max_num_seqs < 1:
            expr = (f"min({max_num_seqs_orig}, "
                    f"{max_num_batched_tokens} // {max_mm_tokens})")
            logger.warning(
                "Computed max_num_seqs (%s) to be less than 1. "
                "Setting it to the minimum value of 1.", expr)
            max_num_seqs = 1

    batch_size = 0
    for group_id in range(max_num_seqs):
        # Distribute max_num_batched_tokens as evenly as possible: the first
        # (max_num_batched_tokens % max_num_seqs) groups get one extra token.
        seq_len = (max_num_batched_tokens // max_num_seqs +
                   (group_id < max_num_batched_tokens % max_num_seqs))
        batch_size += seq_len

        dummy_data = self.input_registry \
            .dummy_data_for_profiling(self.model_config,
                                      seq_len,
                                      self.mm_registry)

        seq = SequenceGroupMetadata(
            request_id=str(group_id),
            is_prompt=True,
            seq_data={group_id: dummy_data.seq_data},
            sampling_params=sampling_params,
            block_tables=None,
            lora_request=dummy_lora_requests_per_seq[group_id]
            if dummy_lora_requests_per_seq else None,
            multi_modal_data=dummy_data.multi_modal_data,
            multi_modal_placeholders=dummy_data.multi_modal_placeholders,
        )
        seqs.append(seq)

    # Run the model with the dummy inputs.
    num_layers = self.model_config.get_num_layers(self.parallel_config)
    # use an empty tensor instead of `None`` to force Dynamo to pass
    # it by reference, rather by specializing on the value ``None``.
    # the `dtype` argument does not matter, and we use `float32` as
    # a placeholder (it has wide hardware support).
    # it is important to create tensors inside the loop, rather than
    # multiplying the list, to avoid Dynamo from treating them as
    # tensor aliasing.
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: support kv cache int8
    '''
    # Each layer's cache is a [cache, scale] pair so the int8 kv-cache path
    # has a per-layer scale tensor to fill in.
    kv_caches = []
    for _ in range(num_layers):
        kv_cache_ = torch.tensor([], dtype=torch.float32, device=self.device)
        kv_cache_scale_ = torch.tensor([], dtype=torch.float32, device=self.device)
        kv_caches.append([kv_cache_, kv_cache_scale_])
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    finished_requests_ids = [seq.request_id for seq in seqs]
    """
    ======================================
    Modified by Chunked Parallel Pipeline.
    ======================================
    @brief: Add two parameters: prefill_loc and token_chunk_size.
    """
    # Every profiled group is a fresh prompt, so prefill starts at offset 0
    # and the chunk covers the group's whole token budget.
    token_chunk_sizes = [seq.token_chunk_size for seq in seqs]
    model_input = self.prepare_model_input(
        seqs,
        finished_requests_ids=finished_requests_ids,
        prefill_locs=[0]*len(seqs),
        token_chunk_sizes=token_chunk_sizes,
    )
    """
    ======================================
    End by Chunked Parallel Pipeline.
    ======================================
    """
    intermediate_tensors = None
    if not get_pp_group().is_first_rank:
        # Non-first pp stages receive hidden states instead of token ids.
        intermediate_tensors = self.model.make_empty_intermediate_tensors(
            batch_size=batch_size,
            dtype=self.model_config.dtype,
            device=self.device)

    graph_batch_size = self.max_batchsize_to_capture
    batch_size_capture_list = [
        bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
    ]
    if self.model_config.enforce_eager:
        # Eager mode captures no graphs.
        batch_size_capture_list = []
    with set_compile_context(batch_size_capture_list):
        self.execute_model(model_input, kv_caches, intermediate_tensors)
    # Wait for the forward pass so the peak allocation is visible to the
    # memory profiler before this function returns.
    torch.mlu.synchronize()

    return
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
MLUModelRunnerBase,
|
||||
MLUModelRunnerBase.profile_run,
|
||||
vllm__worker__mlu_model_runner__MLUModelRunnerBase__profile_run
|
||||
)
|
||||
@@ -0,0 +1,304 @@
|
||||
import dataclasses
|
||||
import weakref
|
||||
from typing import (List, Optional, TypeVar)
|
||||
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import (
|
||||
GPUModelRunnerBase,
|
||||
ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata,
|
||||
ModelRunner,
|
||||
TModelInputForGPU
|
||||
)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
"""
|
||||
======================================
|
||||
Modified by Chunked Parallel Pipeline.
|
||||
======================================
|
||||
@brief: Add two parameters, prefill_loc and token_chunk_size.
|
||||
"""
|
||||
def vllm__worker__model_runner__ModelInputForGPUBuilder___compute_lens(
    self, inter_data: ModelInputForGPUBuilder.InterDataForSeqGroup,
    seq_idx: int, seq_group_metadata: SequenceGroupMetadata,
    prefill_loc: Optional[int] = None,
    token_chunk_size: Optional[int] = None,
):
    """
    ======================================
    Modified by Chunked Parallel Pipeline.
    ======================================
    Compute context length, sequence length and tokens for the given
    sequence data.

    Chunked-pipeline extension: multiple tasks may share the same sequence
    data at different prefill offsets, so the caller can pass the offset and
    chunk size explicitly instead of relying on the per-sequence counters.

    Args:
        inter_data: per-sequence-group scratch record being filled in.
        seq_idx: index of the sequence within inter_data.
        seq_group_metadata: scheduler metadata for the group.
        prefill_loc: explicit prefill start offset; when None, falls back to
            seq_data.get_num_computed_tokens().
        token_chunk_size: explicit chunk size; when None, falls back to
            seq_group_metadata.token_chunk_size.
    """
    seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
    if token_chunk_size is None:
        token_chunk_size = seq_group_metadata.token_chunk_size

    # Compute context length (the number of tokens that are
    # already computed) and sequence length (total number of tokens).
    seq_len = seq_data.get_len()
    if inter_data.is_prompt:
        # For chunked pipeline parallel, since multiple tasks use the same
        # sequence data with different prefill locations, an extra parameter
        # indicates the prefill location.
        context_len = (
            prefill_loc if prefill_loc is not None
            else seq_data.get_num_computed_tokens()
        )
        seq_len = min(seq_len, context_len + token_chunk_size)
    elif self.runner.scheduler_config.is_multi_step or \
            self.runner.model_config.is_encoder_decoder:
        assert prefill_loc is None, "Chunked Parallel Pipeline does not support multi-step."
        context_len = seq_len - 1
    else:
        context_len = seq_data.get_num_computed_tokens()

    # Compute tokens.
    tokens = seq_data.get_token_ids()[context_len:seq_len]

    inter_data.seq_lens[seq_idx] = seq_len
    inter_data.orig_seq_lens[seq_idx] = seq_len
    inter_data.context_lens[seq_idx] = context_len
    inter_data.input_tokens[seq_idx].extend(tokens)
    inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
    inter_data.query_lens[seq_idx] = seq_len - context_len

    if seq_data.mrope_position_delta is not None:
        # FIX: MRotaryEmbedding is never imported at the top of this hijack
        # module (the upstream model_runner imports it), which made this
        # branch raise NameError for mrope models. Import it lazily so the
        # common path pays no import cost.
        from vllm.model_executor.layers.rotary_embedding import (
            MRotaryEmbedding)
        if inter_data.mrope_input_positions is None:
            inter_data.mrope_input_positions = [None] * inter_data.n_seqs

        inter_data.mrope_input_positions[
            seq_idx] = MRotaryEmbedding.get_next_input_positions(
                seq_data.mrope_position_delta,
                context_len,
                seq_len,
            )
|
||||
|
||||
|
||||
"""
|
||||
======================================
|
||||
Modified by Chunked Parallel Pipeline.
|
||||
======================================
|
||||
@brief: Add two parameters, prefill_loc and token_chunk_size.
|
||||
"""
|
||||
def vllm__worker__model_runner__ModelInputForGPUBuilder__add_seq_group(
    self, seq_group_metadata: SequenceGroupMetadata,
    prefill_loc: Optional[int] = None,
    token_chunk_size: Optional[int] = None,
):
    """
    ======================================
    Modified by Chunked Parallel Pipeline.
    ======================================
    Add a sequence group to the builder.

    Carries the chunked-pipeline prefill_loc / token_chunk_size through to
    the hijacked _compute_lens hook; all other per-sequence hooks keep their
    original 3-argument call.
    """
    _CHUNKED_COMPUTE_LENS = (
        "vllm__worker__model_runner__ModelInputForGPUBuilder___compute_lens")

    group_seq_ids = seq_group_metadata.seq_data.keys()
    num_seqs = len(group_seq_ids)
    prompt_group = seq_group_metadata.is_prompt

    if prompt_group:
        assert num_seqs == 1
        self.decode_only = False

    enc_len = 0
    if self.runner.model_config.is_encoder_decoder:
        enc_len = seq_group_metadata.encoder_seq_data.get_len()

    inter_data = self.init_cached_inter_data(
        request_id=seq_group_metadata.request_id,
        seq_ids=group_seq_ids,
        is_prompt=prompt_group,
        block_tables=seq_group_metadata.block_tables,
        computed_block_nums=seq_group_metadata.computed_block_nums,
        reinit=True,
        reinit_use_defaults=True,
        encoder_seq_len=enc_len)

    self.inter_data_list.append(inter_data)

    for idx in range(num_seqs):
        for fn in self.per_seq_compute_fns:
            # Only the hijacked _compute_lens (matched by qualname) accepts
            # the two extra chunked-pipeline arguments.
            if fn.__qualname__ == _CHUNKED_COMPUTE_LENS:
                fn(inter_data, idx, seq_group_metadata, prefill_loc, token_chunk_size)
            else:
                fn(inter_data, idx, seq_group_metadata)

    for group_fn in self.per_seq_group_compute_fns:
        group_fn(inter_data, seq_group_metadata)
|
||||
|
||||
|
||||
def vllm__worker__model_runner__GPUModelRunnerBase___prepare_model_input_tensors(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    finished_requests_ids: Optional[List[str]] = None,
    prefill_locs: Optional[List[int]] = None,
    token_chunk_sizes: Optional[List[int]] = None,
) -> TModelInputForGPU:
    """
    ======================================
    Modified by Chunked Parallel Pipeline.
    ======================================
    Prepare the model input for the base forward pass (no sampling metadata).

    The input assumes seq_group_metadata_list is sorted prefill -> decode,
    and the resulting tensors are batched in the same order:

    - input_tokens[:num_prefill_tokens] contains prefill tokens.
    - input_tokens[num_prefill_tokens:] contains decode tokens.

    If cuda graph is required, inputs are padded automatically.

    Chunked-pipeline extension: prefill_locs / token_chunk_sizes (one entry
    per sequence group, or None to use the per-group defaults) are forwarded
    to the builder.
    """
    builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)

    num_groups = len(seq_group_metadata_list)
    # Normalize the chunked-pipeline parameters: absent lists mean "use the
    # per-group defaults"; present lists must match the group count.
    if prefill_locs is None:
        prefill_locs = [None] * num_groups
    assert len(prefill_locs) == num_groups, \
        "the lengths of prefill locs and seq_group_metadata are different."

    if token_chunk_sizes is None:
        token_chunk_sizes = [None] * num_groups
    assert len(token_chunk_sizes) == num_groups, \
        "the lengths of token_chunk_sizes and seq_group_metadata are different."

    for metadata, loc, chunk in zip(seq_group_metadata_list, prefill_locs,
                                    token_chunk_sizes):
        builder.add_seq_group(metadata, loc, chunk)

    builder.reset_cached_inter_data()

    return builder.build()  # type: ignore
|
||||
|
||||
"""
|
||||
======================================
|
||||
Modified by Chunked Parallel Pipeline.
|
||||
======================================
|
||||
@brief: Add two parameters, prefill_loc and token_chunk_size.
|
||||
"""
|
||||
def vllm__worker__model_runner__ModelRunner__prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None,
|
||||
prefill_locs: Optional[List[int]] = None,
|
||||
token_chunk_sizes: Optional[List[int]] = None,
|
||||
) -> ModelInputForGPUWithSamplingMetadata:
|
||||
"""
|
||||
======================================
|
||||
End by Chunked Parallel Pipeline.
|
||||
======================================
|
||||
"""
|
||||
"""Prepare the model input based on a given sequence group, including
|
||||
metadata for the sampling step.
|
||||
|
||||
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
|
||||
|
||||
The result tensors and data structure also batches input in prefill
|
||||
-> decode order. For example,
|
||||
|
||||
- input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
- input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
|
||||
If cuda graph is required, this API automatically pads inputs.
|
||||
"""
|
||||
"""
|
||||
======================================
|
||||
Modified by Chunked Parallel Pipeline.
|
||||
======================================
|
||||
Add prefill location parameter.
|
||||
"""
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
seq_group_metadata_list,
|
||||
finished_requests_ids,
|
||||
prefill_locs,
|
||||
token_chunk_sizes)
|
||||
"""
|
||||
======================================
|
||||
End by Chunked Parallel Pipeline.
|
||||
======================================
|
||||
"""
|
||||
if get_pp_group().is_last_rank:
|
||||
# Sampling metadata is only required for the final pp group
|
||||
generators = self.get_generators(finished_requests_ids)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list, model_input.seq_lens,
|
||||
model_input.query_lens, self.device, self.pin_memory,
|
||||
generators, self.sampling_metadata_cache)
|
||||
else:
|
||||
sampling_metadata = None
|
||||
is_prompt = (seq_group_metadata_list[0].is_prompt
|
||||
if seq_group_metadata_list else None)
|
||||
return dataclasses.replace(model_input,
|
||||
sampling_metadata=sampling_metadata,
|
||||
is_prompt=is_prompt,
|
||||
virtual_engine=virtual_engine)
|
||||
|
||||
# Install the chunked-parallel-pipeline replacements over the stock vLLM
# model-runner entry points.  Order matches the original individual calls.
for _target_cls, _orig_fn, _patched_fn in (
    (ModelInputForGPUBuilder,
     ModelInputForGPUBuilder._compute_lens,
     vllm__worker__model_runner__ModelInputForGPUBuilder___compute_lens),
    (ModelInputForGPUBuilder,
     ModelInputForGPUBuilder.add_seq_group,
     vllm__worker__model_runner__ModelInputForGPUBuilder__add_seq_group),
    (GPUModelRunnerBase,
     GPUModelRunnerBase._prepare_model_input_tensors,
     vllm__worker__model_runner__GPUModelRunnerBase___prepare_model_input_tensors),
    (ModelRunner,
     ModelRunner.prepare_model_input,
     vllm__worker__model_runner__ModelRunner__prepare_model_input),
):
    MluHijackObject.apply_hijack(_target_cls, _orig_fn, _patched_fn)
|
||||
@@ -0,0 +1,219 @@
|
||||
import dataclasses
|
||||
import importlib
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors)
|
||||
from vllm.utils import (enable_trace_function_call_for_thread,
|
||||
update_environment_variables)
|
||||
from vllm.worker.model_runner_base import (BroadcastableModelInput,
|
||||
ModelRunnerBase,
|
||||
ModelRunnerInputBase)
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
|
||||
WorkerInput,
|
||||
extract_previous_hidden_states)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest,
        prefill_locs: Optional[List[int]] = None,
        # Annotation improved: downstream code (execute_model) does
        # len(token_chunk_sizes) and token_chunk_sizes[0], so this is a list
        # of per-sequence chunk sizes, not a single int.
        token_chunk_sizes: Optional[List[int]] = None,
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
    """ Get the driver input and broadcast it to other workers. """
    assert self.is_driver_worker

    worker_input: WorkerInput = self.prepare_worker_input(
        execute_model_req=execute_model_req)
    """
    ======================================
    Modified by Chunked Parallel Pipeline.
    ======================================
    Pass prefill location and chunk size parameters.
    """
    # Chunked-pipeline change: forward prefill_locs / token_chunk_sizes into
    # model-input preparation on the driver.
    model_input: ModelRunnerInputBase = (
        self.model_runner.prepare_model_input(
            execute_model_req.seq_group_metadata_list,
            execute_model_req.virtual_engine,
            execute_model_req.finished_requests_ids,
            prefill_locs,
            token_chunk_sizes))
    """
    ======================================
    End by Chunked Parallel Pipeline.
    ======================================
    """

    # Extra kwargs (e.g. previous hidden states for speculative paths).
    kwargs = extract_previous_hidden_states(execute_model_req)

    if self.do_metadata_broadcast:
        # Driver (src=0) publishes worker/model inputs to the other workers
        # in the group as one broadcast.
        broadcast_data = worker_input.as_broadcastable_tensor_dict()
        broadcast_data.update(model_input.as_broadcastable_tensor_dict())
        broadcast_data.update(kwargs)
        broadcast_tensor_dict(broadcast_data, src=0)

    if execute_model_req.async_callback:
        # Attach the engine's async callback without mutating the original
        # (frozen dataclass) model input.
        model_input = dataclasses.replace(  # type: ignore
            model_input,
            async_callback=execute_model_req.async_callback)

    return model_input, worker_input, kwargs
|
||||
|
||||
def vllm__worker__worker_base__LocalOrDistributedWorkerBase__prepare_input(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None,
    prefill_locs: Optional[List[int]] = None,
    # Annotation improved: a list of per-sequence chunk sizes (see
    # execute_model / _get_driver_input_and_broadcast), not a single int.
    token_chunk_sizes: Optional[List[int]] = None,
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
        str, torch.Tensor]]]:
    """
    Prepare the inputs to ModelRunner and workers.

    Driver rank: builds (and, with multiple workers, broadcasts) the inputs.
    Non-driver ranks: receive the inputs from the broadcast.  Returns None
    when the driver signals there is no more work.
    """
    if self.is_driver_worker:
        if execute_model_req is None:
            if self.do_metadata_broadcast:
                # This signals that there's no more requests to process for
                # now. All workers are running infinite loop with
                # broadcast_tensor_dict, and it stops the loop when the
                # driver broadcasts an empty input. Send an empty input to
                # notify all other workers to stop their execution loop.
                broadcast_tensor_dict({}, src=0)
            return None
        """
        ======================================
        Modified by Chunked Parallel Pipeline.
        ======================================
        Pass prefill location and chunk size parameters.
        """
        return self._get_driver_input_and_broadcast(
            execute_model_req, prefill_locs, token_chunk_sizes)
        """
        ======================================
        End by Chunked Parallel Pipeline.
        ======================================
        """
    else:
        return self._get_worker_input_from_broadcast()
|
||||
|
||||
def vllm__worker__worker_base__LocalOrDistributedWorkerBase__execute_model(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None,
    prefill_locs: Optional[List[int]] = None,
    # Annotation improved: this must be a single-element list (asserted
    # below), not a bare int.
    token_chunk_sizes: Optional[List[int]] = None,
    priority: int = -1,
) -> Optional[List[SamplerOutput]]:
    """Executes at least one model step on the given sequences, unless no
    sequences are provided."""
    start_time = time.perf_counter()

    """
    ======================================
    Modified by Chunked Parallel Pipeline.
    ======================================
    Pass prefill location and chunk size parameters.
    """
    inputs = self.prepare_input(execute_model_req, prefill_locs, token_chunk_sizes)
    """
    ======================================
    End by Chunked Parallel Pipeline.
    ======================================
    """
    if inputs is None:
        # Driver signalled shutdown of the execution loop.
        return None

    model_input, worker_input, kwargs = inputs
    num_steps = worker_input.num_steps

    # Cache operations (swap in/out, copy) before the forward pass.
    self.execute_worker(worker_input)

    # If there is no input, we don't need to execute the model.
    if worker_input.num_seq_groups == 0:
        return []

    """
    ======================================
    Modified by Chunked Parallel Pipeline.
    ======================================
    @brief: To prevent the execution of mlu pipeline interrupted by host communication,
        cancel the host communication and prepare metadata list directly.
    """
    # The chunked pipeline path requires exactly one chunk size per step; it
    # determines the batch size of the inter-stage intermediate tensors, so
    # each rank can build the recv metadata locally (no host handshake).
    assert (token_chunk_sizes is not None and len(token_chunk_sizes) == 1)
    batch_size = token_chunk_sizes[0]
    metadata_list = self.model_runner.model.get_intermediate_tensor_metadata(
        batch_size,
        dtype=self.model_runner.model_config.dtype,
        device=self.model_runner.device)

    intermediate_tensors = None
    orig_model_execute_time = 0.0
    if not get_pp_group().is_first_rank:
        # Non-first pipeline stages receive activations from the previous
        # stage using the locally computed metadata.
        intermediate_tensors = IntermediateTensors(
            get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group(),
                recv_metadata_list=metadata_list))
        if (self.observability_config is not None
                and self.observability_config.collect_model_execute_time):
            # Accumulated execute time from upstream stages, if tracked.
            orig_model_execute_time = intermediate_tensors.tensors.get(
                "model_execute_time", torch.tensor(0)).item()
    """
    ======================================
    End by Chunked Parallel Pipeline.
    ======================================
    """

    output = self.model_runner.execute_model(
        model_input=model_input,
        kv_caches=self.kv_cache[worker_input.virtual_engine]
        if self.kv_cache is not None else None,
        intermediate_tensors=intermediate_tensors,
        num_steps=num_steps,
        **kwargs,
    )

    model_execute_time = time.perf_counter() - start_time
    if not get_pp_group().is_last_rank:
        # output is IntermediateTensors
        if (self.observability_config is not None
                and self.observability_config.collect_model_execute_time):
            output.tensors["model_execute_time"] = torch.tensor(
                model_execute_time + orig_model_execute_time)
        # Forward activations to the next pipeline stage; non-last ranks
        # produce no sampler output.
        get_pp_group().send_tensor_dict(output.tensors,
                                        all_gather_group=get_tp_group())
        return [None]
    if (self.observability_config is not None
            and self.observability_config.collect_model_execute_time
            and output is not None):
        for o in output:
            o.model_execute_time = (orig_model_execute_time +
                                    model_execute_time)

    # output is List[SamplerOutput]
    return output
|
||||
|
||||
# Install the chunked-parallel-pipeline replacements over the stock vLLM
# worker-base entry points.  Order matches the original individual calls.
for _target_cls, _orig_fn, _patched_fn in (
    (LocalOrDistributedWorkerBase,
     LocalOrDistributedWorkerBase.prepare_input,
     vllm__worker__worker_base__LocalOrDistributedWorkerBase__prepare_input),
    (LocalOrDistributedWorkerBase,
     LocalOrDistributedWorkerBase._get_driver_input_and_broadcast,
     vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_driver_input_and_broadcast),
    (LocalOrDistributedWorkerBase,
     LocalOrDistributedWorkerBase.execute_model,
     vllm__worker__worker_base__LocalOrDistributedWorkerBase__execute_model),
):
    MluHijackObject.apply_hijack(_target_cls, _orig_fn, _patched_fn)
|
||||
@@ -0,0 +1,27 @@
|
||||
### 简介
|
||||
|
||||
该example是vLLM中进行Context Parallel和Ring Attention的实验,mlu_hijack是对仓库代码的劫持,避免修改主仓库代码
|
||||
|
||||
### 支持模型
|
||||
|
||||
目前仅对LLaMA2系列模型进行了精度验证
|
||||
|
||||
### 支持板卡
|
||||
|
||||
暂不支持300系列设备
|
||||
|
||||
### 运行demo
|
||||
```bash
|
||||
python examples/cambricon_custom_func/context_parallel/offline_inference.py
|
||||
```
|
||||
|
||||
### 使用Context Parallel特性
|
||||
|
||||
设置环境变量export CONTEXT_PARALLEL_EN=1|True|true|TRUE, LLM主接口传入context_parallel_size参数
|
||||
|
||||
### 实现细节
|
||||
|
||||
- 为了使Ring Attention实现负载均衡,数据使用了zigzag的拆分方式
|
||||
- 需要的MLU卡数为world_size = context_parallel_size * tensor_parallel_size,先拆cp, 然后拆tp
|
||||
- 目前只是用作实验验证,context阶段采用cp,decoder阶段只在一个cp group上进行
|
||||
- 支持kv cache int8量化
|
||||
@@ -0,0 +1,83 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
import argparse
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
if __name__ == '__main__':
    # Offline benchmark driver for the context-parallel / ring-attention
    # experiment: builds one random prompt of --input_len tokens and either
    # measures generate() latency (--latency) or runs a single generation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        help="support /data/AE/llm/models/Llama-2-7b-hf/, \
                              /data/AE/llm/models/Llama-2-13b-hf/, \
                              /data/AE/llm/models/Llama-2-70b-hf/")
    parser.add_argument('--input_len', type=int, default=4096)
    parser.add_argument('--output_len', type=int, default=1)
    parser.add_argument("--tensor_parallel_size", "-tp", type=int, help="tp")
    parser.add_argument("--context_parallel_size", "-cp", type=int, help="cp")
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument('--num_iters_warmup',
                        type=int,
                        default=3,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num_iters',
                        type=int,
                        default=10,
                        help='Number of iterations to run.')
    parser.add_argument('--trust_remote_code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument('--latency',
                        action='store_true',
                        help='get context latency')
    args = parser.parse_args()

    print("model: ", args.model)
    print("seq_len: ", args.input_len)
    print("tensor_parallel_size: ", args.tensor_parallel_size)
    print("context_parallel_size: ", args.context_parallel_size)

    sampling_params = SamplingParams(temperature=0.8, max_tokens=args.output_len)
    # max_num_batched_tokens == input_len with max_num_seqs == 1 forces a
    # single-sequence prefill of exactly input_len tokens per step.
    llm = LLM(model=args.model, enforce_eager=True, max_model_len = args.input_len,
              max_num_batched_tokens = args.input_len, max_num_seqs = 1,
              tensor_parallel_size = args.tensor_parallel_size,
              context_parallel_size = args.context_parallel_size)

    # Fixed seed -> reproducible random prompt across runs.
    np.random.seed(0)
    dummy_prompt_token_ids = np.random.randint(10000, size=(1, args.input_len))
    dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()

    if args.latency:
        def run_to_completion():
            # Wall-clock time of one full generate() call (prefill + decode).
            start_time = time.perf_counter()
            llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                         sampling_params=sampling_params,
                         use_tqdm=False)
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

        print("Warming up...")
        for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
            run_to_completion()

        # Benchmark.
        latencies = []
        for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
            latencies.append(run_to_completion())
        latencies = np.array(latencies)
        percentages = [10, 25, 50, 75, 90]
        percentiles = np.percentile(latencies, percentages)
        print(f'Avg latency: {np.mean(latencies)} seconds')
        for percentage, percentile in zip(percentages, percentiles):
            print(f'{percentage}% percentile latency: {percentile} seconds')
        # NOTE(review): get_metrics is not part of the upstream vLLM LLM API —
        # presumably an MLU-fork extension; confirm its positional signature.
        llm.get_metrics(args.num_iters_warmup,False,args.input_len,args.output_len,args.tensor_parallel_size,args.quantization)
    else:
        outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids, sampling_params = sampling_params)
||||
@@ -0,0 +1 @@
|
||||
from .backends import mlu_attn
|
||||
@@ -0,0 +1,58 @@
|
||||
from typing import Optional, Type
|
||||
import torch
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.backends.mlu_attn import MLUFlashAttentionMetadata
|
||||
from vllm_mlu.attention.backends.mlu_attn import MLUFlashAttentionImpl_V2
|
||||
|
||||
from .ring_attn import zigzag_ring_attn
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_context_model_parallel_world_size)
|
||||
|
||||
|
||||
# Keep a reference to the pre-hijack forward so the non-CP path is unchanged.
vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_org = MLUFlashAttentionImpl_V2.forward


def vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_wraper(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: MLUFlashAttentionMetadata,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
    attn_type: AttentionType = AttentionType.DECODER,
    use_mla: bool = False,
) -> torch.Tensor:
    '''
    ==========================
    Modify by Context Parallel
    ==========================
    @brief: use ring attn when context parallel
    '''
    # When context parallelism is enabled and this batch contains prefill
    # tokens, run zigzag ring attention across the CP group instead of the
    # stock flash-attention path.
    if get_context_model_parallel_world_size() > 1 and attn_metadata.prefill_metadata:
        return zigzag_ring_attn(
            self,
            query=query.view(-1, self.num_heads, self.head_size),
            key=key.view(-1, self.num_kv_heads, self.head_size),
            value=value.view(-1, self.num_kv_heads, self.head_size),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata)
    '''
    =======================
    End of Context Parallel
    =======================
    '''
    # Fall back to the original implementation.  Bug fix: `use_mla` was
    # previously accepted but silently dropped here, so MLA callers fell back
    # to the default.  Forward it explicitly (the wrapper's signature was
    # copied from the hijacked forward, which accepts it — TODO confirm on
    # older vllm_mlu versions).
    return vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_org(
        self,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
        k_scale=k_scale,
        v_scale=v_scale,
        attn_type=attn_type,
        use_mla=use_mla)
|
||||
|
||||
|
||||
# Replace MLUFlashAttentionImpl_V2.forward with the context-parallel-aware
# wrapper defined above (non-CP calls still reach the saved original).
MluHijackObject.apply_hijack(MLUFlashAttentionImpl_V2,
                             MLUFlashAttentionImpl_V2.forward,
                             vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_wraper)
|
||||
@@ -0,0 +1,216 @@
|
||||
from typing import List, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.backends.mlu_attn import MLUFlashAttentionMetadata
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import get_context_model_parallel_group
|
||||
from ...distributed.ring_comm import RingComm
|
||||
|
||||
|
||||
# code references: https://github.com/zhuzilin/ring-flash-attention
|
||||
def _update_out_and_lse(
|
||||
out: torch.Tensor,
|
||||
lse: torch.Tensor,
|
||||
block_out: torch.Tensor,
|
||||
block_lse: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
block_out = block_out.to(torch.float32)
|
||||
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
|
||||
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
|
||||
lse = lse - F.logsigmoid(lse - block_lse)
|
||||
return out, lse
|
||||
|
||||
|
||||
def update_out_and_lse(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    slice_=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Initialize or merge the running (out, lse) pair with one block.

    First call (out is None): the block simply becomes the accumulator
    (out promoted to float32, lse reshaped to [seq, heads, 1]).
    Later calls merge via _update_out_and_lse; when slice_ is given, only
    that slice of the accumulators is merged, in place.
    """
    if out is None:
        if slice_ is not None:
            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
        return (block_out.to(torch.float32),
                block_lse.transpose(-2, -1).unsqueeze(dim=-1))

    if slice_ is None:
        # Full-tensor merge.
        return _update_out_and_lse(out, lse, block_out, block_lse)

    # Partial merge: update only the requested slice of the accumulators.
    out[slice_], lse[slice_] = _update_out_and_lse(
        out[slice_], lse[slice_], block_out, block_lse)
    return out, lse
|
||||
|
||||
|
||||
def get_half(pack_tensor, cu_seq_lens, first_half):
    """Gather the first or second half of every sequence in a packed tensor.

    pack_tensor holds all sequences concatenated along dim 0; cu_seq_lens is
    the cumulative boundary tensor ([0, len0, len0+len1, ...]).  Returns the
    selected halves re-packed along dim 0.  Assumes even sequence lengths.
    """
    halves = []
    for seq_start, seq_end in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        midpoint = (seq_start + seq_end) // 2
        lo, hi = (seq_start, midpoint) if first_half else (midpoint, seq_end)
        halves.append(pack_tensor[lo:hi])
    return torch.cat(halves, dim=0)
|
||||
|
||||
|
||||
def update_half(pack_tensor, half_tensor, cu_seq_lens, first_half):
    """Scatter per-sequence halves back into a packed tensor, in place.

    Inverse of get_half: half_tensor packs one half of every sequence
    (boundaries cu_seq_lens // 2, so all sequence lengths must be even) and
    is written over the corresponding half of pack_tensor.  Returns None.
    """
    half_boundaries = cu_seq_lens // 2
    for batch, (seq_start, seq_end) in enumerate(
            zip(cu_seq_lens[:-1], cu_seq_lens[1:])):
        midpoint = (seq_start + seq_end) // 2
        dst_lo, dst_hi = (seq_start, midpoint) if first_half else (midpoint, seq_end)
        pack_tensor[dst_lo:dst_hi] = \
            half_tensor[half_boundaries[batch]:half_boundaries[batch + 1]]
|
||||
|
||||
|
||||
def zigzag_ring_attn(self,
                     query: torch.Tensor,  # [num_tokens, num_heads, head_size]
                     value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
                     key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
                     kv_cache: List[torch.Tensor],
                     attn_metadata: MLUFlashAttentionMetadata) -> torch.Tensor:
    """Zigzag ring attention across the context-parallel group (prefill).

    Each CP rank holds a zigzag shard of every sequence.  K/V (and the paged
    cache slot mapping) are rotated around the CP ring step by step; partial
    attention results are merged with the numerically stable (out, lse)
    update.  Adapted from https://github.com/zhuzilin/ring-flash-attention.
    """
    num_tokens, _, _ = query.shape
    # Cumulative sequence boundaries of the packed prefill batch.
    cu_seq_lens = attn_metadata.prefill_metadata.seq_start_loc
    batch_num = cu_seq_lens.shape[0] - 1
    block_seq_len = query.shape[0] // 2
    process_group = get_context_model_parallel_group().device_group
    # One ring channel per rotated tensor so the exchanges can be batched
    # and waited on independently.
    comm = RingComm(process_group)  # k
    comm_ = RingComm(process_group)  # v
    comm__ = RingComm(process_group)  # slot_mapping

    q, k, v = query, key, value
    # q1: the second zigzag half of the local queries — the only part that
    # attends to KV blocks arriving from ranks "after" this one.
    if batch_num == 1:
        q1 = q[block_seq_len:]
    else:
        q1 = get_half(q, cu_seq_lens, False)
    slot_mapping = attn_metadata.slot_mapping

    out = None
    lse = None
    next_k, next_v = None, None
    next_slot_mapping = None

    def forward(q, k, v, causal):
        # One local flash-attention call; returns the block output and its
        # log-sum-exp reshaped to the packed [head_num_q, total_seq_q] layout.
        if batch_num == 1:
            seq = q.shape[0]
            seq_k = k.shape[0]
            cu_seq_lens_q = torch.arange(0, seq+1, seq, dtype=torch.int32, device=q.device)
            cu_seq_lens_kv = torch.arange(0, seq_k+1, seq_k, dtype=torch.int32, device=q.device)
            max_seq_len_q = seq
            max_seq_len_kv = seq_k
        else:
            max_seq_len_q = attn_metadata.prefill_metadata.max_seq_len
            max_seq_len_kv = attn_metadata.prefill_metadata.max_seq_len
            cu_seq_lens_q = cu_seq_lens
            cu_seq_lens_kv = cu_seq_lens
            # A halved q (or kv) means this call was handed only one zigzag
            # half, so the packed boundaries shrink accordingly.
            if q.shape[0] != cu_seq_lens[-1]:
                cu_seq_lens_q = cu_seq_lens // 2
                max_seq_len_q = max_seq_len_q // 2
            if k.shape[0] != cu_seq_lens[-1]:
                cu_seq_lens_kv = cu_seq_lens // 2
                max_seq_len_kv = max_seq_len_kv // 2
        alibi_slopes = None if self.alibi_slopes is None else \
            self.alibi_slopes.repeat(attn_metadata.num_prefills, 1)
        # Final True asks the kernel to also return the lse (index 1 below).
        ouptuts = mlu_ops.flash_attention(q,
                                          k,
                                          v,
                                          None,
                                          cu_seq_lens_q,
                                          cu_seq_lens_kv,
                                          alibi_slopes,
                                          None,
                                          max_seq_len_q,
                                          max_seq_len_kv,
                                          self.scale,
                                          causal, -1, -1, torch.float, True)
        block_out, block_lse = ouptuts[0], ouptuts[1]

        if block_lse.shape[0] == 1:
            block_lse = block_lse[0]
        else:
            # block_lse shape is [batch, head_num_q, max_seq_q], the empty part will set 0
            # we need to modify the shape to [batch, head_num_q, total_seq_q]
            block_lse_list = []
            for batch in range(block_lse.shape[0]):
                block_lse_ = block_lse[batch][:, : cu_seq_lens_q[batch + 1] - cu_seq_lens_q[batch]]
                block_lse_list.append(block_lse_)
            block_lse = torch.cat(block_lse_list, dim=-1)

        return block_out, block_lse

    for step in range(comm.world_size):
        # Pre-post the async ring exchange for the next step so the transfer
        # overlaps with this step's attention compute.
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k.contiguous())
            next_v: torch.Tensor = comm_.send_recv(v.contiguous())
            next_slot_mapping: torch.Tensor = comm__.send_recv(slot_mapping)
            comm.commit()
            comm_.commit()
            comm__.commit()

        # call mlu_ops.reshape_paged_cache
        # Persist the KV block currently held (local or received) into the
        # paged cache using its accompanying slot mapping.
        if kv_cache[0].numel() > 0:
            kv_cache_, kv_cache_scale_ = kv_cache
            key_cache, value_cache = kv_cache_[0], kv_cache_[1]
            if isinstance(kv_cache[0], torch.Tensor) and kv_cache[0].dtype == torch.int8:
                # int8 KV cache: quantize while writing.
                key_cache_scale, value_cache_scale = kv_cache_scale_[0], kv_cache_scale_[1]
                mlu_ops.quant_to_paged_cache(k,
                                             v,
                                             key_cache,
                                             value_cache,
                                             key_cache_scale,
                                             value_cache_scale,
                                             slot_mapping.flatten())
            else:
                mlu_ops.reshape_paged_cache(k,
                                            v,
                                            key_cache,
                                            value_cache,
                                            slot_mapping.flatten())

        if step == 0:
            # Local KV: plain causal attention over the local shard.
            block_out, block_lse = forward(q, k, v, causal = True)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        elif step <= comm.rank:
            # Zigzag case 1: all local queries attend, but only to the first
            # half of the received KV block (non-causal).
            if batch_num == 1:
                k0 = k[:block_seq_len]
                v0 = v[:block_seq_len]
            else:
                k0 = get_half(k, cu_seq_lens, True)
                v0 = get_half(v, cu_seq_lens, True)
            block_out, block_lse = forward(q, k0, v0, causal = False)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        else:
            # Zigzag case 2: only the second-half queries (q1) attend, to the
            # full received KV block; merge into that half of out/lse only.
            block_out, block_lse = forward(q1, k, v, causal = False)
            if batch_num == 1:
                out, lse = update_out_and_lse(out, lse, block_out, block_lse,
                                              slice_=(slice(block_seq_len, None)),)
            else:
                slice_out = get_half(out, cu_seq_lens, False)
                slice_lse = get_half(lse, cu_seq_lens, False)
                slice_out, slice_lse = update_out_and_lse(
                    slice_out, slice_lse, block_out, block_lse
                )
                update_half(out, slice_out, cu_seq_lens, False)
                update_half(lse, slice_lse, cu_seq_lens, False)

        if step + 1 != comm.world_size:
            # Wait for the pre-posted exchange, then rotate the buffers for
            # the next step.
            comm.wait()
            comm_.wait()
            comm__.wait()
            k = next_k
            v = next_v
            slot_mapping = next_slot_mapping
    # Accumulators were kept in float32; cast back and flatten heads.
    out = out.to(q.dtype)
    return out.view(num_tokens, self.num_heads * self.head_size)
|
||||
@@ -0,0 +1 @@
|
||||
from . import ring_comm
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# code references: https://github.com/zhuzilin/ring-flash-attention
|
||||
# code references: https://github.com/zhuzilin/ring-flash-attention
class RingComm:
    """Asynchronous ring exchange helper over a torch.distributed group.

    send_recv() queues a matched P2P pair (send to the next rank on the
    ring, receive from the previous rank), commit() launches every queued
    op as one batch, and wait() blocks until the batch completes and resets
    the internal state for reuse.
    """

    def __init__(self, process_group: dist.ProcessGroup):
        self._process_group = process_group
        self._ops = []
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        self._reqs = None

        # Ring neighbours, first as group-local ranks...
        local_next = (self.rank + 1) % self.world_size
        local_prev = (self.rank - 1) % self.world_size
        self.send_rank = local_next
        self.recv_rank = local_prev

        # ...then translated to global ranks, which P2POp expects.
        if process_group is not None:
            self.send_rank = dist.get_global_rank(self._process_group, local_next)
            self.recv_rank = dist.get_global_rank(self._process_group, local_prev)

    def send_recv(
        self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Queue a send of `to_send` and a matching receive; return the
        receive buffer (freshly allocated unless `recv_tensor` is given)."""
        res = torch.empty_like(to_send) if recv_tensor is None else recv_tensor

        self._ops.append(dist.P2POp(
            dist.isend, to_send, self.send_rank, group=self._process_group))
        self._ops.append(dist.P2POp(
            dist.irecv, res, self.recv_rank, group=self._process_group))
        return res

    def commit(self):
        """Launch all queued P2P ops as a single batch."""
        if self._reqs is not None:
            raise RuntimeError("commit called twice")
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        """Block until the committed batch finishes, then reset state."""
        if self._reqs is None:
            raise RuntimeError("wait called before commit")
        for req in self._reqs:
            req.wait()
        self._reqs = None
        self._ops = []
|
||||
@@ -0,0 +1,2 @@
|
||||
from . import gpu_executor
|
||||
from . import ray_mlu_executor
|
||||
@@ -0,0 +1,40 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
def vllm__executor__gpu_executor__GPUExecutor___get_worker_kwargs(
    self,
    local_rank: int = 0,
    rank: int = 0,
    distributed_init_method: Optional[str] = None,
) -> Dict[str, Any]:
    """Return worker init args for a given rank."""
    if distributed_init_method is None:
        # Bug fix: these helpers were referenced without being imported in
        # this hijack module, so hitting this fallback raised NameError.
        # Import lazily from vllm.utils (where the original GPUExecutor
        # gets them) right where they are needed.
        from vllm.utils import (get_distributed_init_method, get_ip,
                                get_open_port)
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
    '''
    ==========================
    Modify by Context Parallel
    ==========================
    @brief: replace self.parallel_config.tensor_parallel_size with self.parallel_config.world_size.
    '''
    # With context parallel, the driver is rank 0 of the *full* world
    # (cp * tp), not just of the TP group.
    return dict(
        vllm_config=self.vllm_config,
        local_rank=local_rank,
        rank=rank,
        distributed_init_method=distributed_init_method,
        is_driver_worker=(not self.parallel_config)
        or (rank % self.parallel_config.world_size == 0),
    )
    '''
    =======================
    End of Context Parallel
    =======================
    '''
|
||||
'''
|
||||
=======================
|
||||
End of Context Parallel
|
||||
=======================
|
||||
'''
|
||||
|
||||
|
||||
# Swap GPUExecutor._get_worker_kwargs for the context-parallel-aware
# variant defined above.
MluHijackObject.apply_hijack(
    GPUExecutor,
    GPUExecutor._get_worker_kwargs,
    vllm__executor__gpu_executor__GPUExecutor___get_worker_kwargs)
|
||||
@@ -0,0 +1,246 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
get_vllm_instance_id)
|
||||
from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG, VLLM_LATENCY_DEBUG_NO_DEVICE
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
from vllm.executor.ray_mlu_executor import RayMLUExecutor
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
if ray is not None:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__executor__ray_mlu_executor__RayMLUExecutor___init_workers_ray(
        self, placement_group: "PlacementGroup",
        **ray_remote_kwargs):
    """Hijacked ``RayMLUExecutor._init_workers_ray``.

    Creates one Ray actor per placement-group bundle, sorts workers into
    driver-node-first order, pushes per-node environment variables, then
    initializes devices and loads the model on every worker.  The
    context-parallel modification is at the very end: the TP-driver /
    non-driver split is computed with ``parallel_config.world_size``
    instead of the tensor-parallel size.
    """
    if (self.parallel_config.tensor_parallel_size == 1
            and self.parallel_config.pipeline_parallel_size == 1):
        # For single GPU case, we use a ray worker with constrained memory.
        num_gpus = self.cache_config.gpu_memory_utilization
    else:
        # Otherwise, the ray workers are allocated with a full GPU.
        num_gpus = 1

    # The driver dummy worker does not actually use any resources.
    # It holds the resource for the driver worker.
    self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
    # The remaining workers are the actual ray actors.
    self.workers: List[RayWorkerWrapper] = []

    # Used in ray compiled DAG: indexed first by PP rank,
    # and then TP rank. In other words, the inner list is
    # the TP group of workers for a PP rank.
    self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

    if self.parallel_config.ray_workers_use_nsight:
        ray_remote_kwargs = self._configure_ray_workers_use_nsight(
            ray_remote_kwargs)

    logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

    # Create the workers.
    driver_ip = get_ip()
    worker_wrapper_kwargs = self._get_worker_wrapper_args()
    for bundle_id, bundle in enumerate(placement_group.bundle_specs):
        # NOTE(review): bundles are filtered on the "GPU" resource key even
        # though visible devices are exported as MLU_VISIBLE_DEVICES below —
        # presumably MLUs are registered as "GPU" resources; confirm.
        if not bundle.get("GPU", 0):
            continue
        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_id,
        )

        worker = ray.remote(
            num_cpus=0,
            num_gpus=num_gpus,
            scheduling_strategy=scheduling_strategy,
            **ray_remote_kwargs,
        )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)

        if self.use_ray_spmd_worker:
            self.workers.append(worker)
        else:
            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    **worker_wrapper_kwargs)
            else:
                # Else, added to the list of workers.
                self.workers.append(worker)

    logger.debug("workers: %s", self.workers)
    logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
    if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
        raise ValueError(
            "Ray does not allocate any GPUs on the driver node. Consider "
            "adjusting the Ray placement group or running the driver on a "
            "GPU node.")

    worker_ips = [
        ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
        for worker in self.workers
    ]
    ip_counts: Dict[str, int] = {}
    for ip in worker_ips:
        ip_counts[ip] = ip_counts.get(ip, 0) + 1

    def sort_by_driver_then_worker_ip(worker):
        """
        Sort the workers based on 3 properties:
        1. If the worker is on the same node as the driver (vllm engine),
            it should be placed first.
        2. Then, if the worker is on a node with fewer workers, it should
            be placed first.
        3. Finally, if the work is on a node with smaller IP address, it
            should be placed first.
        """
        ip = ray.get(worker.get_node_ip.remote())
        return (ip != driver_ip, ip_counts[ip], ip)

    # After sorting, the workers on the same node will be
    # close to each other, and the workers on the driver
    # node will be placed first.
    self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

    # Get the set of GPU IDs used on each node.
    worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                use_dummy_driver=True)

    node_workers = defaultdict(list)  # node id -> list of worker ranks
    node_gpus = defaultdict(list)  # node id -> list of gpu ids

    for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
        node_workers[node_id].append(i)
        # `gpu_ids` can be a list of strings or integers.
        # convert them to integers for consistency.
        # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
        # string sorting is not sufficient.
        # see https://github.com/vllm-project/vllm/issues/5590
        gpu_ids = [int(x) for x in gpu_ids]
        node_gpus[node_id].extend(gpu_ids)
    for node_id, gpu_ids in node_gpus.items():
        node_gpus[node_id] = sorted(gpu_ids)

    all_ips = set(worker_ips + [driver_ip])
    n_ips = len(all_ips)
    n_nodes = len(node_workers)

    if n_nodes != n_ips:
        raise RuntimeError(
            f"Every node should have a unique IP address. Got {n_nodes}"
            f" nodes with node ids {list(node_workers.keys())} and "
            f"{n_ips} unique IP addresses {all_ips}. Please check your"
            " network configuration. If you set `VLLM_HOST_IP` or "
            "`HOST_IP` environment variable, make sure it is unique for"
            " each node.")

    VLLM_INSTANCE_ID = get_vllm_instance_id()

    # Set environment variables for the driver and workers.
    all_args_to_update_environment_variables = [({
        "MLU_VISIBLE_DEVICES":
        ",".join(map(str, node_gpus[node_id])),
        "VLLM_INSTANCE_ID":
        VLLM_INSTANCE_ID,
        "VLLM_TRACE_FUNCTION":
        str(envs.VLLM_TRACE_FUNCTION),
        **({
            "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
        } if envs.VLLM_ATTENTION_BACKEND is not None else {}),
        "VLLM_LATENCY_DEBUG":
        '1' if VLLM_LATENCY_DEBUG else '0',
        "VLLM_LATENCY_DEBUG_NO_DEVICE":
        '1' if VLLM_LATENCY_DEBUG_NO_DEVICE else '0',
    }, ) for (node_id, _) in worker_node_and_gpu_ids]

    self._env_vars_for_all_workers = (
        all_args_to_update_environment_variables)

    self._run_workers("update_environment_variables",
                      all_args=self._get_env_vars_to_be_updated())

    if len(node_gpus) == 1:
        # in single node case, we don't need to get the IP address.
        # the loopback address is sufficient
        # NOTE: a node may have several IP addresses, one for each
        # network interface. `get_ip()` might return any of them,
        # while they might not work for communication inside the node
        # if the network setup is complicated. Using the loopback address
        # solves this issue, as it always works for communication inside
        # the node.
        driver_ip = "127.0.0.1"
    distributed_init_method = get_distributed_init_method(
        driver_ip, get_open_port())

    # Initialize the actual workers inside worker wrapper.
    init_worker_all_kwargs = [
        self._get_worker_kwargs(
            local_rank=node_workers[node_id].index(rank),
            rank=rank,
            distributed_init_method=distributed_init_method,
        ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
    ]
    self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

    self._run_workers("init_device")
    self._run_workers("load_model",
                      max_concurrent_workers=self.parallel_config.
                      max_parallel_loading_workers)

    if self.use_ray_spmd_worker:
        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
            self.pp_tp_workers.append([])
            for tp_rank in range(
                    self.parallel_config.tensor_parallel_size):
                # PP=2, TP=4
                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                rank = (pp_rank * self.parallel_config.tensor_parallel_size
                        ) + tp_rank
                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                assert pp_rank < len(self.pp_tp_workers)
                self.pp_tp_workers[pp_rank].append(self.workers[rank])

    # This is the list of workers that are rank 0 of each TP group EXCEPT
    # global rank 0. These are the workers that will broadcast to the
    # rest of the workers.
    self.tp_driver_workers: List[RayWorkerWrapper] = []
    # This is the list of workers that are not drivers and not the first
    # worker in a TP group. These are the workers that will be
    # broadcasted to.
    self.non_driver_workers: List[RayWorkerWrapper] = []

    # Enforce rank order for correct rank to return final output.
    for index, worker in enumerate(self.workers):
        # The driver worker is rank 0 and not in self.workers.
        rank = index + 1
        '''
        ==========================
        Modify by Context Parallel
        ==========================
        @brief: replace tp size with world_size.
        '''
        if rank % self.parallel_config.world_size == 0:
            self.tp_driver_workers.append(worker)
        else:
            self.non_driver_workers.append(worker)
        '''
        =======================
        End of Context Parallel
        =======================
        '''
|
||||
|
||||
# Install the context-parallel override for Ray worker initialization.
MluHijackObject.apply_hijack(RayMLUExecutor,
                             RayMLUExecutor._init_workers_ray,
                             vllm__executor__ray_mlu_executor__RayMLUExecutor___init_workers_ray)
|
||||
@@ -0,0 +1,6 @@
|
||||
# Package entry point for the context-parallel demo hijacks: importing each
# submodule applies its MluHijackObject patches as an import side effect.
print("Apply Context Parallel Demo!")
from . import distributed
from . import attention
from . import model_executor
from . import worker
from . import executor
|
||||
@@ -0,0 +1,2 @@
|
||||
from .layers import rotary_embedding
|
||||
from .layers import logits_processor
|
||||
@@ -0,0 +1,110 @@
|
||||
from typing import Optional
|
||||
import torch
|
||||
import vllm
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.distributed import get_world_group
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor, _prune_hidden_states, _apply_logits_processors
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_context_model_parallel_world_size, get_context_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
|
||||
|
||||
def vllm__module_executor__layers__logits_processor__LogitsProcessor__forward_wraper(
    self,
    lm_head: VocabParallelEmbedding,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hijacked ``LogitsProcessor.forward`` with context-parallel support.

    When context parallel is active (cp world size > 1) and prefill attention
    metadata was attached to this processor (the model-runner hijack does
    ``setattr(logits_processor, 'attn_metadata', ...)`` before
    ``compute_logits``), hidden states are pruned with the
    context-parallel-aware helper; otherwise the stock pruning is used.

    Args:
        lm_head: vocab-parallel embedding used to project hidden states.
        hidden_states: flattened hidden states for the scheduled tokens.
        sampling_metadata: selection indices / per-seq sampling params.
        embedding_bias: optional bias added during the logits projection.

    Returns:
        The (possibly soft-capped, scaled, processor-adjusted) logits, or
        ``hidden_states`` unchanged when ``logits_as_input`` is set.
    """
    if self.logits_as_input:
        logits = hidden_states
    else:
        '''
        ==========================
        Modify by Context Parallel
        ==========================
        @brief: context parallel requires special handling of hidden_states and logits
        '''
        # Robustness fix: `attn_metadata` is only attached via setattr() by
        # the hijacked model runner.  Use getattr() with a default so a call
        # path that never attached it falls back to stock pruning instead of
        # raising AttributeError.
        attn_metadata = getattr(self, "attn_metadata", None)
        if attn_metadata and get_context_model_parallel_world_size() > 1:
            hidden_states = _prune_hidden_states_context_parallel(
                hidden_states, sampling_metadata, attn_metadata)
        else:
            hidden_states = _prune_hidden_states(hidden_states,
                                                 sampling_metadata)
        '''
        =======================
        End of Context Parallel
        =======================
        '''
        # Get the logits for the next tokens.
        logits = self._get_logits(hidden_states, lm_head, embedding_bias)
    if logits is not None:
        # Soft cap: tanh-squash logits into (-soft_cap, soft_cap).
        if self.soft_cap is not None:
            logits = logits / self.soft_cap
            logits = torch.tanh(logits)
            logits = logits * self.soft_cap

        if self.scale != 1.0:
            logits *= self.scale

        # Apply logits processors (if any).
        if sampling_metadata is not None:
            logits = _apply_logits_processors(logits, sampling_metadata)

    return logits
|
||||
|
||||
|
||||
'''
==========================
Modify by Context Parallel
==========================
@brief: token num can be divisible by context_parallel_size * 2 after padding,
        and then split to context parallel groups with zigzag method, now we
        need to find the last valid tokens, and get the logits for the next tokens.
'''
def _prune_hidden_states_context_parallel(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    attn_metadata: AttentionMetadata
) -> torch.Tensor:
    # Select, per sequence, the hidden state of its last valid token after the
    # zigzag context-parallel split, broadcasting it from whichever CP rank
    # owns that token so every rank returns the same (batch, hidden) tensor.
    select_hidden_states_list = []
    seq_start_loc = attn_metadata.prefill_metadata.seq_start_loc
    batch_num = seq_start_loc.shape[0] - 1
    for batch in range(batch_num):
        start = seq_start_loc[batch]
        end = seq_start_loc[batch + 1]
        # Local (already zigzag-split) slice for this sequence: two chunks of
        # equal length, the front chunk and the mirrored back chunk.
        hidden_states_ = hidden_states[start : end]
        split_seq_len = hidden_states_.shape[0] // 2
        seq_len = attn_metadata.prefill_metadata.seq_lens[batch]
        last_id = seq_len - 1
        # Global zigzag chunk index that contains the last valid token.
        idx = last_id // split_seq_len
        # Placeholder on ranks that do not own the token; overwritten by the
        # broadcast below.
        select_hidden_states = torch.zeros((1, hidden_states.shape[-1]), dtype = hidden_states.dtype, device = hidden_states.device)
        # NOTE(review): src_rank = tp_world_size * cp_id assumes global ranks
        # are laid out TP-contiguous within each CP group and that TP rank 0
        # of the group holds the value — confirm against the process-group
        # layout in parallel_state.
        if idx < get_context_model_parallel_world_size():
            # Last token lives in the *front* chunk of CP rank `idx`.
            target_cp_id = idx
            src_rank = get_tensor_model_parallel_world_size() * target_cp_id
            if get_context_model_parallel_rank() == target_cp_id:
                selected_token_indices = last_id - idx * split_seq_len
                select_hidden_states = hidden_states_[selected_token_indices].unsqueeze(0)
        else:
            # Last token lives in the *mirrored back* chunk: zigzag maps chunk
            # idx to CP rank (2*cp_world - 1 - idx).
            target_cp_id = get_context_model_parallel_world_size() * 2 - 1 - idx
            src_rank = get_tensor_model_parallel_world_size() * target_cp_id
            if get_context_model_parallel_rank() == target_cp_id:
                # Offset by split_seq_len to index into the local back chunk.
                selected_token_indices = last_id - idx * split_seq_len + split_seq_len
                select_hidden_states = hidden_states_[selected_token_indices].unsqueeze(0)

        # Share the owner's row with every rank in the world group.
        select_hidden_states = get_world_group().broadcast(select_hidden_states, src = src_rank)
        select_hidden_states_list.append(select_hidden_states)

    select_hidden_states = torch.cat(select_hidden_states_list, dim=0)
    return select_hidden_states
'''
=======================
End of Context Parallel
=======================
'''
|
||||
|
||||
|
||||
# Install the context-parallel override for LogitsProcessor.forward.
MluHijackObject.apply_hijack(LogitsProcessor,
                             LogitsProcessor.forward,
                             vllm__module_executor__layers__logits_processor__LogitsProcessor__forward_wraper)
|
||||
@@ -0,0 +1,62 @@
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
import vllm
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.model_executor.layers.rotary_embedding import MLURotaryEmbedding
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_context_model_parallel_world_size)
|
||||
|
||||
def vllm__module_executor__layers__rotary_embedding__MLURotaryEmbedding__forward_mlu_wraper(
    self,
    positions: torch.Tensor,
    x: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hijacked ``MLURotaryEmbedding.forward_mlu`` with context-parallel support.

    Applies rotary embedding via the MLU kernel.  Under context parallel the
    kernel must run in discrete (per-token position) mode, because each CP
    rank only holds a zigzag slice of the sequence.

    Args:
        positions: position ids for the tokens in ``x``.
        x: packed query/key tensor updated in place by the kernel.
        offsets: unsupported; raises ``ValueError`` when given.

    Returns:
        The tensor returned by ``mlu_ops.rotary_embedding`` (in-place update
        of ``x``).
    """
    from vllm import _mlu_ops as mlu_ops

    # ops.rotary_embedding()/batched_rotary_embedding()
    # are in-place operations that update the query and key tensors.
    if offsets is not None:
        raise ValueError(f"tmo.apply_rotary not support offsets yet.")
    # Lazily cache the cos/sin tables on the class the first time through.
    if not MLURotaryEmbedding.set_cos_sin:
        MLURotaryEmbedding.cos_, MLURotaryEmbedding.sin_ = self._get_cos_sin()
        MLURotaryEmbedding.set_cos_sin = True
    # neox style uses the non-interleaved (half-rotated) layout.
    interleaved = not self.is_neox_style
    # Stock behavior: discrete per-token positions for chunked prefill and
    # decode; batched (position_ids=None) mode for a plain prompt.
    if MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
        position_ids = positions
        discrete = True
    else:
        position_ids = None
        discrete = False
    '''
    ==========================
    Modify by Context Parallel
    ==========================
    @brief: context parallel need discrete = True
    '''
    # BUG FIX: the original wrote
    #   ... if (is_prompt and get_context_model_parallel_world_size == 1) ...
    # comparing the *function object* to 1 (missing call parentheses), which
    # is always False — so discrete mode was forced on every forward,
    # clobbering the stock prompt/chunked logic above even with CP disabled.
    # Only override the stock choice when context parallel is actually on.
    if get_context_model_parallel_world_size() > 1:
        position_ids = positions
        discrete = True
    '''
    =======================
    End of Context Parallel
    =======================
    '''
    x = mlu_ops.rotary_embedding(x,
                                 MLURotaryEmbedding.sin_,
                                 MLURotaryEmbedding.cos_,
                                 position_ids,
                                 MLURotaryEmbedding.cu_seq_lens,
                                 interleaved,
                                 discrete,
                                 False,
                                 MLURotaryEmbedding.max_seq_len)

    return x
|
||||
|
||||
|
||||
# Install the context-parallel override for MLURotaryEmbedding.forward_mlu.
MluHijackObject.apply_hijack(MLURotaryEmbedding,
                             MLURotaryEmbedding.forward_mlu,
                             vllm__module_executor__layers__rotary_embedding__MLURotaryEmbedding__forward_mlu_wraper)
|
||||
@@ -0,0 +1,5 @@
|
||||
from . import mlu_model_runner
|
||||
from . import model_runner
|
||||
from . import model_runner_base
|
||||
from . import worker
|
||||
from . import worker_base
|
||||
@@ -0,0 +1,256 @@
|
||||
import torch
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
|
||||
Tuple, Type, TypeVar, Union)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm.worker.model_runner import (
|
||||
TModelInputForGPU, ModelInputForGPU,
|
||||
ModelInputForGPUWithSamplingMetadata,
|
||||
ModelInputForGPUBuilder, GPUModelRunnerBase,
|
||||
ModelRunner, CUDAGraphRunner,
|
||||
LORA_WARMUP_RANK, _get_graph_batch_size,
|
||||
_BATCH_SIZES_TO_CAPTURE, _NUM_WARMUP_ITERS
|
||||
)
|
||||
from vllm.worker.mlu_model_runner import MLUModelRunner
|
||||
from vllm.sequence import (IntermediateTensors, SequenceGroupMetadata)
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from ..zigzag_utils import get_context_model_parallel_world_size, zigzag_split
|
||||
import vllm.envs as envs
|
||||
|
||||
# flashinfer is optional: fall back to None placeholders when absent so this
# module still imports on systems without flashinfer (e.g. MLU hosts).
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

# Sentinel slot id used to pad slot mappings during the zigzag split —
# presumably matched against by the attention kernel; confirm in zigzag_utils.
_PAD_SLOT_ID = -1
|
||||
|
||||
@torch.inference_mode()
def vllm__worker__mlu_model_runner__MLUModelRunner__execute_model(
    self,
    model_input: ModelInputForGPUWithSamplingMetadata,
    kv_caches: List[torch.Tensor],
    intermediate_tensors: Optional[IntermediateTensors] = None,
    num_steps: int = 1,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
    """Hijacked ``MLUModelRunner.execute_model`` with context-parallel support.

    Differences from stock: (1) when context parallel is on and this is a
    prefill step, inputs are zigzag-split across the CP group before the
    model forward; (2) the zigzag attention metadata is attached to the
    logits processor so it can prune hidden states correctly; (3) optional
    MLU latency-debug timing markers are recorded around forward+logits.
    """
    if num_steps > 1:
        raise ValueError("num_steps > 1 is not supported in ModelRunner")

    if self.lora_config:
        assert model_input.lora_requests is not None
        assert model_input.lora_mapping is not None
        self.set_active_loras(model_input.lora_requests,
                              model_input.lora_mapping)

    if self.prompt_adapter_config:
        assert model_input.prompt_adapter_requests is not None
        assert model_input.prompt_adapter_mapping is not None
        self.set_active_prompt_adapters(
            model_input.prompt_adapter_requests,
            model_input.prompt_adapter_mapping)

    self.attn_state.begin_forward(model_input)

    # Currently cuda graph is only supported by the decode phase.
    assert model_input.attn_metadata is not None
    prefill_meta = model_input.attn_metadata.prefill_metadata
    decode_meta = model_input.attn_metadata.decode_metadata
    # TODO(andoorve): We can remove this once all
    # virtual engines share the same kv cache.
    virtual_engine = model_input.virtual_engine
    if prefill_meta is None and decode_meta.use_cuda_graph:
        assert model_input.input_tokens is not None
        graph_batch_size = model_input.input_tokens.shape[0]
        model_executable = self.graph_runners[virtual_engine][
            graph_batch_size]
    else:
        model_executable = self.model

    multi_modal_kwargs = model_input.multi_modal_kwargs or {}
    seqlen_agnostic_kwargs = {
        "finished_requests_ids": model_input.finished_requests_ids,
        "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
    } if self.has_inner_state else {}
    if (self.observability_config is not None
            and self.observability_config.collect_model_forward_time):
        model_forward_start = torch.mlu.Event(enable_timing=True)
        model_forward_end = torch.mlu.Event(enable_timing=True)
        model_forward_start.record()

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add mlu metrics
    '''
    # Add time markers for model_executable+compute_logits
    if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
        # NOTE(review): `use_context_mlugraph` is not defined anywhere in
        # this function or its visible imports — this raises NameError when
        # VLLM_LATENCY_DEBUG_WITH_DEVICE_EN is enabled. Confirm where it was
        # meant to come from (possibly dropped in a merge).
        use_cuda_graph = ((prefill_meta is None and decode_meta.use_cuda_graph)
                          or use_context_mlugraph)
        # if use_cuda_graph, the start timestamp will be inserted inside MLUGraphRunner.forward()
        if not use_cuda_graph:
            start = torch.mlu.Event(enable_timing=True)
            start.record()

    '''
    ==========================
    Modify by Context Parallel
    ==========================
    @brief: context parallel split input for model with zigzag method
    '''
    if get_context_model_parallel_world_size() > 1 and model_input.attn_metadata.prefill_metadata:
        with set_forward_context(model_input.attn_metadata):
            # Split tokens/positions/metadata across the CP group; padding
            # slots use _PAD_SLOT_ID.
            zigzag_input_ids, zigzag_positions, zigzag_attn_metadata = zigzag_split(model_input.input_tokens,
                                                                                   model_input.input_positions,
                                                                                   model_input.attn_metadata, _PAD_SLOT_ID)
            # NOTE(review): this branch passes raw **multi_modal_kwargs while
            # the stock branch below wraps them with MultiModalKwargs.as_kwargs
            # — confirm the asymmetry is intentional.
            hidden_or_intermediate_states = model_executable(
                input_ids=zigzag_input_ids,
                positions=zigzag_positions,
                kv_caches=kv_caches,
                attn_metadata=zigzag_attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **multi_modal_kwargs,
                **seqlen_agnostic_kwargs)
    else:
        with set_forward_context(model_input.attn_metadata):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
                **seqlen_agnostic_kwargs)

    '''
    @brief: logits_processor in context parallel need attn_metadata param
    '''
    # Hand the zigzag metadata to the hijacked LogitsProcessor.forward (it
    # reads self.attn_metadata to pick the CP-aware pruning path); reset to
    # None otherwise so stale metadata from a previous step is never reused.
    if get_context_model_parallel_world_size() > 1 and model_input.attn_metadata.prefill_metadata:
        setattr(self.model.logits_processor, 'attn_metadata', zigzag_attn_metadata)
    else:
        setattr(self.model.logits_processor, 'attn_metadata', None)
    '''
    =======================
    End of Context Parallel
    =======================
    '''

    if (self.observability_config is not None
            and self.observability_config.collect_model_forward_time):
        model_forward_end.record()

    # Compute the logits in the last pipeline stage.
    if not get_pp_group().is_last_rank:
        if (self.is_driver_worker
                and hidden_or_intermediate_states is not None
                and isinstance(hidden_or_intermediate_states,
                               IntermediateTensors)
                and self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end.synchronize()
            model_forward_time = model_forward_start.elapsed_time(
                model_forward_end)
            orig_model_forward_time = 0.0
            if intermediate_tensors is not None:
                orig_model_forward_time = intermediate_tensors.tensors.get(
                    "model_forward_time", torch.tensor(0.0)).item()
            # Accumulate forward time across pipeline stages.
            hidden_or_intermediate_states.tensors["model_forward_time"] = (
                torch.tensor(model_forward_time + orig_model_forward_time))
        return hidden_or_intermediate_states

    logits = self.model.compute_logits(hidden_or_intermediate_states,
                                       model_input.sampling_metadata)

    # Add time markers for model_executable+compute_logits
    if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
        end_marker = torch.mlu.Event(enable_timing=True)
        end_marker.record()
        if use_cuda_graph:
            self.time_markers = (model_executable.start, end_marker)
        else:
            self.time_markers = (start, end_marker)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    if not self.is_driver_worker:
        return []

    if model_input.async_callback is not None:
        model_input.async_callback()

    # Sample the next token.
    output: SamplerOutput = self.model.sample(
        logits=logits,
        sampling_metadata=model_input.sampling_metadata,
    )
    if (self.observability_config is not None
            and self.observability_config.collect_model_forward_time
            and output is not None):
        model_forward_end.synchronize()
        model_forward_time = model_forward_start.elapsed_time(
            model_forward_end)
        orig_model_forward_time = 0.0
        if intermediate_tensors is not None:
            orig_model_forward_time = intermediate_tensors.tensors.get(
                "model_forward_time", torch.tensor(0.0)).item()
        # If there are multiple workers, we are still tracking the latency
        # from the start time of the driver worker to the end time of the
        # driver worker. The model forward time will then end up covering
        # the communication time as well.
        output.model_forward_time = (orig_model_forward_time +
                                     model_forward_time)

    if self.return_hidden_states:
        # we only need to pass hidden states of most recent token
        assert model_input.sampling_metadata is not None
        indices = model_input.sampling_metadata.selected_token_indices
        if model_input.is_prompt:
            hidden_states = hidden_or_intermediate_states.index_select(
                0, indices)
        elif decode_meta.use_cuda_graph:
            hidden_states = hidden_or_intermediate_states[:len(indices)]
        else:
            hidden_states = hidden_or_intermediate_states

        output.hidden_states = hidden_states

    return [output]
|
||||
|
||||
|
||||
# Install the context-parallel override for MLUModelRunner.execute_model.
MluHijackObject.apply_hijack(MLUModelRunner,
                             MLUModelRunner.execute_model,
                             vllm__worker__mlu_model_runner__MLUModelRunner__execute_model)
|
||||
@@ -0,0 +1,35 @@
|
||||
from typing import (Any, Dict, Optional)
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
from examples.cambricon_custom_func.context_parallel.mlu_hijack.worker.model_runner_base import vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict
|
||||
from vllm.worker.model_runner_base import _init_attn_metadata_from_tensor_dict
|
||||
|
||||
@classmethod
def vllm__worker__model_runner__ModelInputForGPUWithSamplingMetadata__from_broadcasted_tensor_dict(
    cls,
    tensor_dict: Dict[str, Any],
    attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForGPUWithSamplingMetadata":
    """Rebuild a ModelInputForGPUWithSamplingMetadata from a broadcast dict.

    Context-parallel modification: the sampling-metadata fields are always
    reconstructed through the hijacked initializer (imported directly), so
    any broadcast ``seq_group_metadata`` is consumed even if the module-level
    hijack was bypassed.
    """
    # Modify by Context Parallel: force-apply the hijacked function.
    tensor_dict = vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict(
        tensor_dict)
    # End of Context Parallel.
    if attn_backend is None:
        return cls(**tensor_dict)
    attn_dict = _init_attn_metadata_from_tensor_dict(attn_backend,
                                                     tensor_dict)
    return cls(**attn_dict)
|
||||
|
||||
# Install the context-parallel override for broadcast-dict deserialization.
MluHijackObject.apply_hijack(
    ModelInputForGPUWithSamplingMetadata,
    ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict,
    vllm__worker__model_runner__ModelInputForGPUWithSamplingMetadata__from_broadcasted_tensor_dict
)
|
||||
@@ -0,0 +1,74 @@
|
||||
from typing import (Any, Dict)
|
||||
|
||||
from vllm.model_executor.sampling_metadata import SequenceGroupToSample
|
||||
from vllm.worker import model_runner_base
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
def vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize SamplingMetadata based on broadcastable
    SamplingMetadata fields.

    Context-parallel modification: when the broadcast dict carries a
    non-empty ``seq_group_metadata`` list, a full SamplingMetadata is
    reconstructed from it (and the key is removed); otherwise an empty
    SamplingMetadata signals that the worker should skip sampling.
    Returns ``tensor_dict`` with ``selected_token_indices`` popped and
    ``sampling_metadata`` inserted when applicable.
    """
    from vllm.model_executor import SamplingMetadata

    selected_token_indices = tensor_dict.pop("selected_token_indices", None)
    if selected_token_indices is not None:
        # Truthiness check replaces `'k' in d.keys() and len(d['k']) > 0`:
        # an absent key and an empty list both take the else branch (and an
        # empty list is intentionally left in the dict, as before).
        seq_group_metadata_list = tensor_dict.get('seq_group_metadata')
        if seq_group_metadata_list:
            '''
            ==========================
            Modify by Context Parallel
            ==========================
            @brief: construct sampling metadata.
            '''
            sequence_group_to_sample_list = []
            for seq_group_metadata in seq_group_metadata_list:
                seq_ids = list(seq_group_metadata.seq_data.keys())
                sampling_params = seq_group_metadata.sampling_params
                seq_data = seq_group_metadata.seq_data
                is_prompt = seq_group_metadata.is_prompt
                if is_prompt:
                    # All sequences in a prompt group share the prompt, so
                    # read the length from the first one.
                    seq_len = query_len = next(
                        iter(seq_data.values())).get_prompt_len()
                else:
                    # Decode step: one new token per sequence.
                    seq_len = None
                    query_len = 1
                prompt_logprob_indices = []
                # NOTE(review): sample indices are set to the sequence *ids*,
                # not positions into selected_token_indices — confirm this is
                # what the CP sampler expects.
                sample_indices = seq_ids
                sequence_group_to_sample = SequenceGroupToSample(
                    seq_ids,
                    sampling_params,
                    seq_data,
                    seq_len,
                    query_len,
                    None,  # Generator
                    is_prompt,
                    prompt_logprob_indices,
                    sample_indices)
                sequence_group_to_sample_list.append(sequence_group_to_sample)
            # NOTE(review): num_prompts counts every group, including decode
            # groups — confirm downstream only uses it when all groups are
            # prompts.
            tensor_dict["sampling_metadata"] = SamplingMetadata(
                seq_groups=sequence_group_to_sample_list,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                num_prompts=len(sequence_group_to_sample_list),
            )
            del tensor_dict['seq_group_metadata']
            '''
            =======================
            End of Context Parallel
            =======================
            '''
        else:
            # An empty SamplingMetadata to signal that the worker should skip
            # sampling.
            tensor_dict["sampling_metadata"] = SamplingMetadata(
                seq_groups=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                num_prompts=0,
            )
    return tensor_dict
|
||||
|
||||
# Install the context-parallel override of the module-level helper in
# vllm.worker.model_runner_base.
MluHijackObject.apply_hijack(
    model_runner_base,
    model_runner_base._init_sampling_metadata_from_tensor_dict,
    vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict
)
|
||||
@@ -0,0 +1,23 @@
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
@property
def vllm__worker__worker__Worker__do_metadata_broadcast(self) -> bool:
    """Hijacked Worker.do_metadata_broadcast.

    =============================
    Modify by Context Parallel
    =============================
    Broadcast worker metadata whenever more than one rank exists in the
    world (covers cp and/or tp > 1) instead of keying off a single
    parallel dimension.
    """
    world_size = self.parallel_config.world_size
    return world_size > 1
|
||||
|
||||
|
||||
# Install the context-parallel override of Worker.do_metadata_broadcast.
MluHijackObject.apply_hijack(
    Worker,
    Worker.do_metadata_broadcast,
    vllm__worker__worker__Worker__do_metadata_broadcast)
|
||||
@@ -0,0 +1,121 @@
|
||||
import dataclasses
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import ObservabilityConfig, VllmConfig
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.worker.model_runner_base import (BroadcastableModelInput,
|
||||
ModelRunnerInputBase)
|
||||
from vllm.worker.worker_base import (extract_previous_hidden_states,
|
||||
LocalOrDistributedWorkerBase,
|
||||
WorkerInput)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def broadcast_tensor_dict(
        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
        src: int = 0
):
    """Broadcast *tensor_dict* from rank *src* over the world group.

    Returns the dict unchanged when torch.distributed has not been
    initialized (single-process runs).
    """
    if torch.distributed.is_initialized():
        return get_world_group().broadcast_tensor_dict(tensor_dict, src)
    return tensor_dict
|
||||
|
||||
def vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_driver_input_and_broadcast(
    self, execute_model_req: ExecuteModelRequest
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
    """ Get the driver input and broadcast it to other workers. """
    # Only the driver rank prepares inputs and broadcasts them.
    assert self.is_driver_worker

    worker_input: WorkerInput = self.prepare_worker_input(
        execute_model_req=execute_model_req)
    model_input: ModelRunnerInputBase = (
        self.model_runner.prepare_model_input(
            execute_model_req.seq_group_metadata_list,
            execute_model_req.virtual_engine,
            execute_model_req.finished_requests_ids))

    # Hidden states carried over from the previous step (spec-decode path).
    kwargs = extract_previous_hidden_states(execute_model_req)

    if self.do_metadata_broadcast:
        broadcast_data = worker_input.as_broadcastable_tensor_dict()
        broadcast_data.update(model_input.as_broadcastable_tensor_dict())
        broadcast_data.update(kwargs)
        '''
        ==========================
        Modify by Context Parallel
        ==========================
        @brief: add seq_group metadata to broadcast.
        '''
        # Non-driver ranks rebuild SamplingMetadata from this raw list (see
        # the hijacked _init_sampling_metadata_from_tensor_dict).
        broadcast_data['seq_group_metadata'] = execute_model_req.seq_group_metadata_list
        '''
        =======================
        End of Context Parallel
        =======================
        '''
        broadcast_tensor_dict(broadcast_data, src=0)

    # Attach the async callback after the broadcast so it stays driver-local.
    if execute_model_req.async_callback:
        model_input = dataclasses.replace(  # type: ignore
            model_input,
            async_callback=execute_model_req.async_callback)

    return model_input, worker_input, kwargs
|
||||
|
||||
def vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_worker_input_from_broadcast(
    self
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
        str, torch.Tensor]]]:
    """Receive the driver's broadcast and rebuild the worker/model inputs.

    Returns None when the driver broadcast an empty dict, which is the
    shutdown signal for the worker execution loop.
    """
    # Only non-driver ranks running with metadata broadcast reach here.
    assert self.do_metadata_broadcast
    assert not self.is_driver_worker

    payload = broadcast_tensor_dict(src=0)
    if not payload:
        # Empty broadcast => no more work from the driver.
        return None

    rebuilt_worker_input = WorkerInput.from_broadcasted_tensor_dict(payload)
    rebuilt_model_input = (
        self.model_runner.make_model_input_from_broadcasted_tensor_dict(
            payload))
    extra_kwargs = extract_previous_hidden_states(payload)
    return rebuilt_model_input, rebuilt_worker_input, extra_kwargs
|
||||
|
||||
|
||||
def vllm__worker__worker_base__LocalOrDistributedWorkerBase__prepare_input(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]]:
    """
    Prepare the inputs to ModelRunner and workers.

    Driver ranks build the inputs (and broadcast them); non-driver ranks
    receive them from the broadcast. A None request on the driver emits an
    empty broadcast so followers exit their receive loop, and returns None.
    """
    # Non-driver ranks simply wait for the driver's broadcast.
    if not self.is_driver_worker:
        return self._get_worker_input_from_broadcast()

    if execute_model_req is None:
        if self.do_metadata_broadcast:
            # No more requests: an empty dict tells every follower rank,
            # which is looping on broadcast_tensor_dict, to stop.
            broadcast_tensor_dict({}, src=0)
        return None

    return self._get_driver_input_and_broadcast(execute_model_req)
|
||||
|
||||
# Install the context-parallel overrides on LocalOrDistributedWorkerBase:
# driver-side broadcast, worker-side receive, and the prepare_input entry.
MluHijackObject.apply_hijack(
    LocalOrDistributedWorkerBase,
    LocalOrDistributedWorkerBase._get_driver_input_and_broadcast,
    vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_driver_input_and_broadcast)

MluHijackObject.apply_hijack(
    LocalOrDistributedWorkerBase,
    LocalOrDistributedWorkerBase._get_worker_input_from_broadcast,
    vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_worker_input_from_broadcast)

MluHijackObject.apply_hijack(
    LocalOrDistributedWorkerBase,
    LocalOrDistributedWorkerBase.prepare_input,
    vllm__worker__worker_base__LocalOrDistributedWorkerBase__prepare_input)
|
||||
@@ -0,0 +1,149 @@
|
||||
from typing import Dict, Optional, Sequence, List
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_context_model_parallel_rank, get_context_model_parallel_world_size, get_context_model_parallel_group)
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.attention import AttentionMetadata
|
||||
import copy
|
||||
|
||||
|
||||
def diff1(result: torch.Tensor, baseline: torch.Tensor):
    """Mean relative L1 error of *result* against *baseline*.

    Both tensors are flattened, cast to float32 and moved to CPU; the
    return value is sum(|baseline - result|) / sum(|baseline|) as a
    Python float. An epsilon guards the all-zero-baseline case.
    """
    res = result.flatten().float().to('cpu')
    ref = baseline.flatten().float().to('cpu')
    assert res.shape == ref.shape
    abs_err = torch.abs(ref - res)
    denom = torch.sum(torch.abs(ref)).item()
    # Epsilon only when the denominator is exactly zero, so non-degenerate
    # baselines are not perturbed.
    eps = 1e-9 if denom <= 0 else 0.0
    return (torch.sum(abs_err) / (denom + eps)).item()
|
||||
|
||||
|
||||
def get_pad_seq(seq_len: int, pad: int):
    """Round *seq_len* up to the nearest multiple of *pad*."""
    return ((seq_len + pad - 1) // pad) * pad
|
||||
|
||||
|
||||
# Gather the partial results of a batch on context parallel groups
# together and place them in the order before zigzag splitting
def context_parallel_tensor_all_gather_(input_, dim=-1):
    """All-gather a zigzag-split shard along *dim* and undo the zigzag order.

    Each rank r holds the pair (chunk_r, chunk_(2W-1-r)) along *dim*
    (W = context-parallel world size); after the gather the 2W chunks are
    re-concatenated as chunk_0 .. chunk_(2W-1).
    """
    world_size = get_context_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()

    # Every shard must contain exactly two zigzag chunks.
    assert input_size[dim] % 2 == 0, (f"input tensor split dim % 2 != 0")

    # All shards share one shape, so a plain all_gather suffices.
    gather_list = [torch.empty(input_.shape, dtype=input_.dtype, device=input_.device) for _ in range(world_size)]
    torch.distributed.all_gather(
        gather_list, input_, group=get_context_model_parallel_group())

    # Rank i contributed (chunk_i, chunk_(2W-1-i)): collect first halves in
    # rank order and second halves in reverse rank order.
    first = []
    second = []
    for i in range(world_size):
        first_second = torch.split(gather_list[i], gather_list[i].shape[dim] // 2, dim=dim)
        first.append(first_second[0])
        second.insert(0, first_second[1])
    tensor_list = first + second
    output_tensor = torch.cat(tensor_list, dim = dim).contiguous()
    return output_tensor
|
||||
|
||||
|
||||
# Gather the partial results of each batch on the context parallel groups together,
# place them in the order before zigzag splitting, and remove the pad part.
# This function is used for debugging
def context_parallel_tensor_all_gather(input, attn_metadata, dim=-1):
    """Debug helper: reassemble the full (unpadded) tensor along *dim*.

    For every batch delimited by prefill seq_start_loc: all-gather that
    batch's shard across the context-parallel group, restore the pre-zigzag
    order, trim the zigzag padding back to the true sequence length, and
    concatenate the per-batch results.
    """
    if dim < 0:
        dim += input.dim()
    # Tuple of full slices for the leading dims so the indexing below only
    # restricts dimension `dim`. BUGFIX: the original loop evaluated
    # `slice_ + (slice(None))` without assigning the result (and without the
    # tuple comma), leaving slice_ empty, so any dim > 0 sliced the wrong axis.
    slice_ = (slice(None),) * dim
    select_list = []
    seq_start_loc = attn_metadata.prefill_metadata.seq_start_loc
    batch_num = seq_start_loc.shape[0] - 1
    for batch in range(batch_num):
        start = seq_start_loc[batch].item()
        end = seq_start_loc[batch + 1].item()
        slice1 = slice_ + (slice(start, end), )
        input_ = input[slice1]
        gather_ = context_parallel_tensor_all_gather_(input_, dim=dim)
        # Drop the zigzag padding: keep only the true sequence length.
        slice2 = slice_ + (slice(None, attn_metadata.prefill_metadata.seq_lens[batch]), )
        select = gather_[slice2]
        select_list.append(select)
    output = torch.cat(select_list, dim=dim)
    return output
|
||||
|
||||
|
||||
# Pad one dimension of a tensor so that it is divisible by context_parallel_size * 2,
# and then use zigzag method to split it into different context parallel groups
def zigzag_split_(tensor: torch.Tensor, dim = -1, pad_value=0):
    """Return this rank's zigzag shard of *tensor* along *dim*.

    The padded tensor is cut into 2*W equal chunks (W = cp world size);
    rank r keeps the pair (chunk_r, chunk_(2W-1-r)), which balances the
    causal-attention workload across ranks.
    """
    if dim < 0:
        dim = tensor.dim() + dim
    split_num = get_context_model_parallel_world_size() * 2
    # Pad up to the next multiple of 2*W so the split is even.
    pad_num = get_pad_seq(tensor.shape[dim], split_num) - tensor.shape[dim]
    # F.pad takes (before, after) pairs starting from the LAST dimension,
    # hence the (0, pad_num) entry sits (ndim - dim - 1) pairs in.
    pad_param = (0, 0) * (tensor.dim() - dim - 1) + (0, pad_num) + (0, 0) * dim
    tensor_pad = F.pad(tensor, pad_param, value = pad_value)
    split_size = divide(tensor_pad.size()[dim], split_num)
    # Split.
    tensor_list = torch.split(tensor_pad, split_size, dim = dim)
    first = tensor_list[get_context_model_parallel_rank()]
    second = tensor_list[split_num - get_context_model_parallel_rank() - 1]
    output_tensor = torch.cat((first, second), dim=dim).contiguous()
    return output_tensor
|
||||
|
||||
|
||||
# Split each batch of input_ids, positions, attn_metadata.slot_mapping with zigzag method,
# and update prefill_metadata.seq_start_loc and prefill_metadata.max_seq_len
def zigzag_split(input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 attn_metadata: AttentionMetadata,
                 pad_slot_id: int):
    """Zigzag-shard the flattened prefill inputs for this cp rank.

    Returns (input_ids, positions, attn_metadata) where each per-batch
    segment has been replaced by this rank's zigzag shard, and the prefill
    metadata (seq/query start locs, max_seq_len, slot_mapping) has been
    updated to match the sharded lengths.
    """
    # Per-batch shards collected here and concatenated at the end.
    zigzag_input_ids: List[torch.Tensor] = []
    zigzag_positions: List[torch.Tensor] = []
    zigzag_slot_mapping: List[torch.Tensor] = []
    # Deep-copied so the caller's metadata stays untouched.
    zigzag_attn_metadata = copy.deepcopy(attn_metadata)
    seq_lens: List[int] = []
    seq_start_loc = attn_metadata.prefill_metadata.seq_start_loc
    batch_num = seq_start_loc.shape[0] - 1
    for batch in range(batch_num):
        start, end = seq_start_loc[batch], seq_start_loc[batch + 1]
        input_ids_ = input_ids[start : end]
        positions_ = positions[start : end]
        zigzag_input_ids_ = zigzag_split_(input_ids_)
        zigzag_positions_ = zigzag_split_(positions_)
        zigzag_input_ids.append(zigzag_input_ids_)
        zigzag_positions.append(zigzag_positions_)
        # Post-split (padded, sharded) length of this batch on this rank.
        seq_lens.append(zigzag_input_ids_.shape[0])
        # Padded positions get pad_slot_id so they target a scratch slot.
        slot_mapping_ = attn_metadata.slot_mapping[start : end]
        zigzag_slot_mapping_ = zigzag_split_(slot_mapping_, pad_value=pad_slot_id)
        zigzag_slot_mapping.append(zigzag_slot_mapping_)

    zigzag_input_ids = torch.cat(zigzag_input_ids, dim=0)
    zigzag_positions = torch.cat(zigzag_positions, dim=0)
    zigzag_slot_mapping = torch.cat(zigzag_slot_mapping, dim=0)

    max_seq_len = max(seq_lens)
    seq_lens_tensor = torch.tensor(seq_lens,
                                   dtype=torch.int,
                                   device=input_ids.device)
    # Rebuild the cumulative start offsets for the sharded lengths.
    seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                dtype=torch.int32,
                                device=input_ids.device)
    torch.cumsum(seq_lens_tensor,
                 dim=0,
                 dtype=seq_start_loc.dtype,
                 out=seq_start_loc[1:])

    # NOTE(review): query_start_loc is set equal to seq_start_loc -- this
    # presumes pure prefill (query_len == seq_len per batch); confirm.
    zigzag_attn_metadata.prefill_metadata.seq_start_loc = seq_start_loc
    zigzag_attn_metadata.prefill_metadata.query_start_loc = seq_start_loc
    zigzag_attn_metadata.prefill_metadata.max_seq_len = max_seq_len
    zigzag_attn_metadata.slot_mapping = zigzag_slot_mapping

    return zigzag_input_ids, zigzag_positions, zigzag_attn_metadata
|
||||
@@ -0,0 +1,25 @@
|
||||
import os

# Must be set BEFORE vllm is imported: the hijack modules read this flag at
# import time to enable the context-parallel code paths.
os.environ['CONTEXT_PARALLEL_EN'] = "True"

from vllm import LLM, SamplingParams

if __name__ == '__main__':
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, max_tokens=16)
    # Create an LLM.
    # NOTE(review): hard-coded local model path; tp=2 x cp=2 requires 4
    # devices and the ray executor backend.
    llm = LLM(model="/data/AE/llm/models/Llama-2-7b-hf/", enforce_eager=True, tensor_parallel_size = 2, context_parallel_size = 2, distributed_executor_backend='ray')
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
@@ -0,0 +1,26 @@
|
||||
### 简介
|
||||
|
||||
该example是vLLM中进行Expert Parallel的实验,mlu_hijack是对仓库代码的劫持,避免修改主仓库代码
|
||||
|
||||
### 支持模型
|
||||
|
||||
- qwen2_moe
|
||||
- mixtral
|
||||
- custom model
|
||||
- deepseek_v2
|
||||
|
||||
### 支持板卡
|
||||
|
||||
300系列设备只能用于功能测试,性能测试需要其他系列设备。
|
||||
|
||||
### 运行demo
|
||||
```python
|
||||
python examples/cambricon_custom_func/expert_parallel/offline_inference.py
|
||||
```
|
||||
|
||||
### 使用Expert Parallel特性
|
||||
|
||||
- 设置环境变量export EXPERT_PARALLEL_EN=1|True|true|TRUE, LLM主接口传入tensor_parallel_size的同时,传入moe_tp_size或moe_ep_size,或两者都传;
|
||||
- 若只传moe_tp_size和moe_ep_size中的一个,另一个等于tensor_parallel_size除以传入值所得的商,所以必须保证传入值能整除tensor_parallel_size;
|
||||
- 若moe_tp_size和moe_ep_size都传入,则必须保证moe_tp_size * moe_ep_size == tensor_parallel_size;
|
||||
- 若moe_tp_size和moe_ep_size都不传,则它们默认值等于-1,即不开启专家并行;
|
||||
@@ -0,0 +1,133 @@
|
||||
#!/bin/bash
# Sweep-benchmark DeepSeek-V2 on MLU devices across tp / moe_ep / pp and
# input / output / batch-size combinations; each run is logged under ./output
# and the sweep for a model stops early on known OOM-style errors.

rm output -rf
mkdir output

DATA_DIR=/data
MODELS_DEEPSEEK_V2=(
    "${DATA_DIR}/vllm/models/LLM-Research/deepseek-v2"
)

MODELS=(${MODELS_DEEPSEEK_V2[@]})

# Variable definitions (0/1 toggles)
use_ray=0
use_eager=0
use_pp=0
# context parameter
input_sizes=(1024)
output_sizes=(1)
# batch_sizes=(1 2 4 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40)
batch_sizes=(1 4 8 16 32)

# decoder parameter
# input_sizes=(1)
# output_sizes=(128)
# batch_sizes=(1 2 4 8 16 32 64 128 256 512 1024 1280 1536 1600 1616 1632 1648 1652 1656 1660 1661 1662 1663 1664 1728 1792 2048)
# batch_sizes=(1 4 8 16 32 64 128 256 512 1024 2048)

tp_sizes=(8)
moe_ep_sizes=(8 -1)
pp_sizes=(1)

# Pipeline-parallel runs use the throughput benchmark; otherwise latency.
if [ $use_pp -gt 0 ]; then
    tp_sizes=(1)
    moe_ep_sizes=(-1)
    pp_sizes=(8)
    BENCHMARK_CMD=benchmarks/benchmark_throughput.py
    benchmark_option="--backend vllm --num-prompts 1000 --output-json output_throughput.csv --async-engine"
else
    BENCHMARK_CMD=benchmarks/benchmark_latency.py
    benchmark_option="--num-iters-warmup 1 --num-iters 3 --only_average"
fi

max_position_embeddings=163840

#export MLU_VISIBLE_DEVICES=4,5,6,7
export EXPERT_PARALLEL_EN=true
export VLLM_LATENCY_DEBUG=true
export VLLM_GRAPH_DEBUG=false
# export VLLM_DUMP_MLU_INFO=true
export OUTPUT_CSV_PATH=/data/solution-sdk/kangpengtao/tmp/deepseek/output.csv

ray_option=""
if [ $use_ray -gt 0 ]; then
    ray_option="--distributed-executor-backend ray --ray-workers-use-nsight"
fi
eager_option=""
if [ $use_eager -gt 0 ]; then
    eager_option="--enforce-eager"
fi

# Iterate over all combinations
for HF_MODEL in "${MODELS[@]}"; do
    quantization_option=""
    if [[ "${HF_MODEL}" == *"sq_per_token_per_channel"* ]]; then
        quantization_option="--quantization=smoothquant"
    fi
    for tp_size in "${tp_sizes[@]}"; do
        for moe_ep_size in "${moe_ep_sizes[@]}"; do
            for pp_size in "${pp_sizes[@]}"; do
                for input_size in "${input_sizes[@]}"; do
                    for output_size in "${output_sizes[@]}"; do
                        for batch_size in "${batch_sizes[@]}"; do
                            max_seq_len_to_capture=$(expr $input_size \+ $output_size)
                            max_num_batched_tokens=$(expr $batch_size \* $input_size)
                            max_model_len=$max_seq_len_to_capture
                            if [ $max_model_len -gt $max_position_embeddings ]; then
                                continue
                            fi
                            # max_num_seqs=256
                            # if [ $max_num_seqs -lt $batch_size ]; then
                            #     max_num_seqs=$batch_size
                            # fi
                            max_num_seqs=$batch_size
                            # Keep max_num_batched_tokens >= max_model_len and >= max_num_seqs.
                            if [ $max_model_len -gt $max_num_batched_tokens ]; then
                                max_num_batched_tokens=$max_model_len
                            fi
                            if [ $max_num_seqs -gt $max_num_batched_tokens ]; then
                                max_num_batched_tokens=$max_num_seqs
                            fi

                            pp_option="--pipeline-parallel-size ${pp_size}"
                            tp_option="-tp ${tp_size}"
                            ep_option="--moe-ep-size ${moe_ep_size}"
                            batch_size_option=""
                            if [ $use_pp -le 0 ]; then
                                batch_size_option="--batch-size ${batch_size}"
                            fi

                            hf_model_name=$(basename "${HF_MODEL}")
                            LOG_FILE=output/${hf_model_name}_${input_size}_${output_size}_tp_${tp_size}_moe_ep_${moe_ep_size}_pp_${pp_size}_bs_${batch_size}.log
                            echo "Executing ${hf_model_name} with tp_size=${tp_size}, moe_ep_size=${moe_ep_size}, pp_size=${pp_size}, input_size=${input_size}, output_size=${output_size}, batch_size=${batch_size}, max_model_len=${max_model_len}, max_num_batched_tokens=${max_num_batched_tokens}"
                            python3 ${BENCHMARK_CMD} \
                                ${benchmark_option} \
                                --trust-remote-code \
                                --max-num-batched-tokens ${max_num_batched_tokens} \
                                --max-model-len ${max_model_len} \
                                --block-size 16 \
                                --model ${HF_MODEL} \
                                --tokenizer ${HF_MODEL} \
                                --dtype bfloat16 \
                                --input-len ${input_size} \
                                --output-len ${output_size} \
                                ${pp_option} ${tp_option} ${ep_option} \
                                --max-seq-len-to-capture ${max_seq_len_to_capture} \
                                --max-num-seqs ${max_num_seqs} \
                                ${batch_size_option} \
                                ${eager_option} ${ray_option} ${quantization_option} \
                                2>&1 | tee ${LOG_FILE}
                            # Check the log for torch.OutOfMemoryError, "Ceil of batch"
                            # or "is larger than mlu blocks" and stop the inner sweep early.
                            if grep -E -q "torch\.OutOfMemoryError|Ceil of batch|is larger than mlu blocks" "$LOG_FILE"; then
                                echo "Found one or more specified errors in the log file."
                                break
                            else
                                echo "No specified errors found."
                            fi
                        done
                    done
                done
            done
        done
    done
done
|
||||
@@ -0,0 +1,147 @@
|
||||
#!/bin/bash
# Same DeepSeek-V2 sweep as the plain benchmark script, but every run is
# launched under cnperf-cli to capture kernel traces; per-run dltrace/cnperf
# data is moved into a uniquely named directory afterwards.

rm output -rf
mkdir output

DATA_DIR=/data
MODELS_DEEPSEEK_V2=(
    "${DATA_DIR}/vllm/models/LLM-Research/deepseek-v2"
)

MODELS=(${MODELS_DEEPSEEK_V2[@]})

# Variable definitions (0/1 toggles)
use_ray=0
use_eager=0
use_pp=0
use_kernel_analysis=0
# context parameter
input_sizes=(1024)
output_sizes=(1)
# batch_sizes=(1 2 4 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40)
batch_sizes=(1 4 8 16 32)

# decoder parameter
# input_sizes=(1)
# output_sizes=(128)
# batch_sizes=(1 2 4 8 16 32 64 128 256 512 1024 1280 1536 1600 1616 1632 1648 1652 1656 1660 1661 1662 1663 1664 1728 1792 2048)
# batch_sizes=(1 4 8 16 32 64 128 256 512 1024 2048)

tp_sizes=(8)
moe_ep_sizes=(8 -1)
pp_sizes=(1)

# Pipeline-parallel runs use the throughput benchmark; otherwise latency.
if [ $use_pp -gt 0 ]; then
    tp_sizes=(1)
    moe_ep_sizes=(-1)
    pp_sizes=(8)
    BENCHMARK_CMD=benchmarks/benchmark_throughput.py
    benchmark_option="--backend vllm --num-prompts 1000 --output-json output_throughput.csv --async-engine"
else
    BENCHMARK_CMD=benchmarks/benchmark_latency.py
    benchmark_option="--num-iters-warmup 1 --num-iters 3 --only_average"
fi
# NOTE(review): benchmark_option is defined above but never passed to the
# cnperf-wrapped command below -- confirm whether that is intentional.

max_position_embeddings=163840

#export MLU_VISIBLE_DEVICES=4,5,6,7
export EXPERT_PARALLEL_EN=true
export VLLM_LATENCY_DEBUG=true
export VLLM_GRAPH_DEBUG=false
# export VLLM_DUMP_MLU_INFO=true
export OUTPUT_CSV_PATH=/data/solution-sdk/kangpengtao/tmp/deepseek/output.csv

ray_option=""
if [ $use_ray -gt 0 ]; then
    ray_option="--distributed-executor-backend ray --ray-workers-use-nsight"
fi

record_option=""
if [ $use_kernel_analysis -gt 0 ]; then
    # ref: https://wiki.cambricon.com/pages/viewpage.action?pageId=434445235
    export CNPERF_KERNEL_ANALYSIS=1
    record_option="--pmu --capture_range=cnpx --cnpx_include kangpengtao --cnpx_exclude kangpengtao_exec --events tp_core__write_bytes,tp_core__read_bytes,tp_memcore__write_bytes,tp_memcore__read_bytes,tp_core__lt_cycles,tp_core__csimd_pre_cycles,tp_core__csimd_post_cycles"
    use_eager=1
fi

eager_option=""
if [ $use_eager -gt 0 ]; then
    eager_option="--enforce-eager"
fi

# Iterate over all combinations
for HF_MODEL in "${MODELS[@]}"; do
    quantization_option=""
    if [[ "${HF_MODEL}" == *"sq_per_token_per_channel"* ]]; then
        quantization_option="--quantization=smoothquant"
    fi
    for tp_size in "${tp_sizes[@]}"; do
        for moe_ep_size in "${moe_ep_sizes[@]}"; do
            for pp_size in "${pp_sizes[@]}"; do
                for input_size in "${input_sizes[@]}"; do
                    for output_size in "${output_sizes[@]}"; do
                        for batch_size in "${batch_sizes[@]}"; do
                            max_seq_len_to_capture=$(expr $input_size \+ $output_size)
                            max_num_batched_tokens=$(expr $batch_size \* $input_size)
                            max_model_len=$max_seq_len_to_capture
                            if [ $max_model_len -gt $max_position_embeddings ]; then
                                continue
                            fi
                            # max_num_seqs=256
                            # if [ $max_num_seqs -lt $batch_size ]; then
                            #     max_num_seqs=$batch_size
                            # fi
                            max_num_seqs=$batch_size
                            # Keep max_num_batched_tokens >= max_model_len and >= max_num_seqs.
                            if [ $max_model_len -gt $max_num_batched_tokens ]; then
                                max_num_batched_tokens=$max_model_len
                            fi
                            if [ $max_num_seqs -gt $max_num_batched_tokens ]; then
                                max_num_batched_tokens=$max_num_seqs
                            fi

                            pp_option="--pipeline-parallel-size ${pp_size}"
                            tp_option="-tp ${tp_size}"
                            ep_option="--moe-ep-size ${moe_ep_size}"
                            batch_size_option=""
                            if [ $use_pp -le 0 ]; then
                                batch_size_option="--batch-size ${batch_size}"
                            fi

                            hf_model_name=$(basename "${HF_MODEL}")
                            LOG_FILE=output/${hf_model_name}_${input_size}_${output_size}_tp_${tp_size}_moe_ep_${moe_ep_size}_pp_${pp_size}_bs_${batch_size}.log
                            echo "Executing ${hf_model_name} with tp_size=${tp_size}, moe_ep_size=${moe_ep_size}, pp_size=${pp_size}, input_size=${input_size}, output_size=${output_size}, batch_size=${batch_size}, max_model_len=${max_model_len}, max_num_batched_tokens=${max_num_batched_tokens}"
                            dltrace_data_name="dltrace_data_${hf_model_name}_${tp_size}_${moe_ep_size}_${pp_size}_${input_size}_${output_size}_${batch_size}_${max_model_len}_${max_num_batched_tokens}"
                            # Clean stale trace output before recording.
                            rm dltrace_data -rf
                            rm cnperf_data_* -rf
                            CNPERF_VLOG_LEVEL=0-40 cnperf-cli record ${record_option} python3 ${BENCHMARK_CMD} \
                                --trust-remote-code \
                                --max-num-batched-tokens ${max_num_batched_tokens} \
                                --max-model-len ${max_model_len} \
                                --block-size 16 \
                                --model ${HF_MODEL} \
                                --tokenizer ${HF_MODEL} \
                                --dtype bfloat16 \
                                --input-len ${input_size} \
                                --output-len ${output_size} \
                                ${pp_option} ${tp_option} ${ep_option} \
                                --max-seq-len-to-capture ${max_seq_len_to_capture} \
                                --max-num-seqs ${max_num_seqs} \
                                ${batch_size_option} \
                                ${eager_option} ${ray_option} ${quantization_option} \
                                2>&1 | tee ${LOG_FILE}
                            # Check the log for torch.OutOfMemoryError, "Ceil of batch"
                            # or "is larger than mlu blocks" and stop the inner sweep early.
                            if grep -E -q "torch\.OutOfMemoryError|Ceil of batch|is larger than mlu blocks" "$LOG_FILE"; then
                                echo "Found one or more specified errors in the log file."
                                break
                            else
                                echo "No specified errors found."
                            fi
                            # Archive this run's trace data under a unique name.
                            mv dltrace_data ${dltrace_data_name}
                            mv cnperf_data_* ${dltrace_data_name}/
                        done
                    done
                done
            done
        done
    done
done
|
||||
@@ -0,0 +1,34 @@
|
||||
#!/bin/bash
# Client-side sweep for the serving benchmark: drives benchmark_serving.py
# against a running api_server with random prompts, one log file per
# (input_size, output_size, batch_size) combination.
# FIXES: shebang was "#/bin/bash" (missing "!"); the log-file name used the
# undefined ${HF_MODEL} (now derived from ${MODEL_PATH}); the input-length
# flag used underscores ("--random_input_len"), which argparse does not
# accept for the dashed option defined by benchmark_serving.py.

# export EXPERT_PARALLEL_EN=True
# export VLLM_LATENCY_DEBUG=True

rm output/client -rf
mkdir -p output/client

PORT=32345
MODEL_PATH="/data/vllm/sq_per_token_per_channel/deepseek_v2_temp"
input_sizes=(1024)
output_sizes=(1)
# batch_sizes=(1 2 4 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40)
batch_sizes=(32)
for input_size in "${input_sizes[@]}"; do
    for output_size in "${output_sizes[@]}"; do
        for batch_size in "${batch_sizes[@]}"; do
            hf_model_name=$(basename "${MODEL_PATH}")
            LOG_FILE=output/client/${hf_model_name}_${input_size}_${output_size}_bs_${batch_size}.log
            python benchmarks/benchmark_serving.py \
                --backend vllm \
                --model ${MODEL_PATH} \
                --trust-remote-code \
                --dataset-name random \
                --num-prompts 1000 \
                --port ${PORT} \
                --request-rate inf \
                --random-input-len ${input_size} \
                --random-output-len ${output_size} \
                --max-concurrency ${batch_size} \
                2>&1 | tee ${LOG_FILE}
        done
    done
done
|
||||
@@ -0,0 +1,2 @@
|
||||
# Package init for the Expert Parallel demo: importing this package applies
# all model_executor hijacks as an import-time side effect.
print("Apply Expert Parallel Demo!")
from . import model_executor
|
||||
@@ -0,0 +1,5 @@
|
||||
from .layers import sparse_moe_mlp
|
||||
from .models import custom
|
||||
from .models import mixtral
|
||||
from .models import qwen2_moe
|
||||
from .models import deepseek_v2
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Inference-only MOE model.
|
||||
|
||||
Tensor Parallel evenly splits each expert's weight and distributes them to different ranks,
|
||||
which means each rank holds partial weight of all experts.
|
||||
While Expert Parallel evenly distributes some of the experts' full weight to different ranks,
|
||||
which means each rank holds part of the experts' full weight.
|
||||
|
||||
As a result, each rank in the Tensor Parallel group receives all tokens' hidden states for all experts,
|
||||
then computes using the partial weights, while for Expert Parallel, each rank only receives
|
||||
part of tokens' hidden states for experts on this rank, then computes using the full weights.
|
||||
|
||||
When both Tensor Parallel and Expert Parallel are enabled, each rank handles
|
||||
a portion of the expert weights matrices (as in EP mode) and these weights are further sliced
|
||||
across ranks (as in TP mode). This hybrid approach aims to balance the workload more evenly across ranks,
|
||||
enhancing efficiency and reducing the likelihood of bottlenecks associated with EP mode alone.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_group)
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size, get_moe_tensor_parallel_group,
|
||||
get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size, get_moe_expert_parallel_group)
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu._mlu_utils import get_device_major_capability
|
||||
|
||||
|
||||
def vllm__mlu_hijack__model_executor__layers__feed_forward__SparseMoeMlp____init__(
    self,
    num_experts: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    up_proj_name: str,
    is_gated: bool,
    down_proj_name: str,
    has_bias: bool,
    skip_bias_add: bool = False,
    renormalize: bool = False,
    hidden_act: str = "silu",
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    is_use_fused_moe: bool = False,
    expert_group: int = 1,
    topk_group: int = 1,
):
    """Hijacked ``SparseMoeMlp.__init__`` adding MoE tensor/expert parallelism.

    Replaces the stock constructor so routed experts are partitioned across
    the MoE expert-parallel (EP) group and each local expert is sharded over
    the MoE tensor-parallel (TP) group instead of the global TP group.
    Only the experts assigned to this EP rank are materialized locally.
    """
    super(SparseMoeMlp, self).__init__()
    self.tp_rank = get_tensor_model_parallel_rank()
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_group = get_tensor_model_parallel_group()
    self.num_total_experts = num_experts
    self.top_k = top_k
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.up_proj_name = up_proj_name
    self.is_gated = is_gated
    self.down_proj_name = down_proj_name
    self.has_bias = has_bias
    self.renormalize = renormalize
    self.hidden_act = hidden_act
    self.quant_config = quant_config
    self.is_use_fused_moe = is_use_fused_moe
    self.expert_group = expert_group
    self.topk_group = topk_group
    # Fused-MoE path is force-disabled on major capability 3 devices.
    if get_device_major_capability() == 3:
        self.is_use_fused_moe = False

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    self.params_dtype = params_dtype

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add moe relative distribution
    '''
    self.moe_tp_size = get_moe_tensor_parallel_world_size()
    self.moe_tp_rank = get_moe_tensor_parallel_rank()
    self.moe_tp_group = get_moe_tensor_parallel_group()
    self.moe_ep_size = get_moe_expert_parallel_world_size()
    self.moe_ep_rank = get_moe_expert_parallel_rank()
    self.moe_ep_group = get_moe_expert_parallel_group()

    # NOTE: The bias for fc2 is only applied on tp_rank 0. If we added it on
    # all nodes the allreduce() would contain multiple copies of the bias.
    # The bias on other nodes will be ignored, and may be set to nullptr.
    self.skip_bias_add = self.moe_tp_rank > 0

    assert self.num_total_experts >= self.moe_ep_size, (
        f"need num_total_experts:{self.num_total_experts} >= moe_ep_size:{self.moe_ep_size}")

    assert self.intermediate_size % self.moe_tp_size == 0, (
        f"need intermediate_size:{self.intermediate_size} % moe_tp_size:{self.moe_tp_size} == 0")

    # Contiguous expert partition: every rank owns ceil(total / ep_size)
    # experts; the last rank owns whatever remains.
    # BUGFIX: the leftover count is (total - start_expert_id), NOT
    # total % moe_ep_size. E.g. 10 experts over 4 ranks -> last rank must
    # own 1 expert, while 10 % 4 == 2 would over-assign it.
    experts_per_rank = (self.num_total_experts + self.moe_ep_size - 1) // self.moe_ep_size
    self.start_expert_id = self.moe_ep_rank * experts_per_rank
    self.num_experts_per_rank = experts_per_rank
    if self.moe_ep_rank + 1 == self.moe_ep_size:
        self.num_experts_per_rank = self.num_total_experts - self.start_expert_id
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    self.end_expert_id = self.start_expert_id + self.num_experts_per_rank

    # Gate always runs at half / full precision for now.
    self.gate = ReplicatedLinear(self.hidden_size,
                                 self.num_total_experts,
                                 bias=False,
                                 params_dtype=self.params_dtype,
                                 quant_config=None)
    # Only this rank's expert shard is instantiated; each expert's linears
    # are sharded over the MoE TP group and reduced by the caller.
    self.experts = nn.ModuleList([
        FeedForward(hidden_size=self.hidden_size,
                    intermediate_size=self.intermediate_size,
                    hidden_act=self.hidden_act,
                    up_proj_name=self.up_proj_name,
                    is_gated=self.is_gated,
                    down_proj_name=self.down_proj_name,
                    bias=self.has_bias,
                    quant_config=self.quant_config,
                    skip_bias_add=self.skip_bias_add,
                    reduce_results=False,
                    tp_group=self.moe_tp_group) for _ in range(self.num_experts_per_rank)
    ])

    self.init_pack_param()


MluHijackObject.apply_hijack(SparseMoeMlp,
                             SparseMoeMlp.__init__,
                             vllm__mlu_hijack__model_executor__layers__feed_forward__SparseMoeMlp____init__)
|
||||
@@ -0,0 +1,183 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm_mlu.transformers_utils.configs import CustomConfig
|
||||
from vllm_mlu.model_executor.custom_model.custom import CustomDecoderLayer, CustomAttention, _NORM_DICT
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm_mlu.model_executor.models.layer_utils import (
|
||||
decoder_layer_forward_base, is_per_tensor_smoothquant,
|
||||
is_per_token_smoothquant, quant_fusion_with_rmsnorm,
|
||||
quant_fusion_with_layernorm)
|
||||
|
||||
|
||||
class CustomMoeBlock(SparseMoeMlp):
    """Sparse-MoE block with an optional always-on shared expert.

    Routed experts come from :class:`SparseMoeMlp`; when the config declares
    a shared expert it is applied to every token and scaled by a sigmoid
    gate before being added to the routed-expert output.
    """

    def __init__(
        self,
        config: CustomConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(num_experts=config.num_experts,
                         top_k=config.num_experts_per_tok,
                         hidden_size=config.hidden_size,
                         intermediate_size=config.moe_intermediate_size,
                         up_proj_name="gate_up_proj",
                         is_gated=config.is_gated,
                         down_proj_name="down_proj",
                         has_bias=config.mlp_bias,
                         skip_bias_add=False,
                         renormalize=config.norm_topk_prob,
                         hidden_act=config.hidden_act,
                         params_dtype=None,
                         quant_config=quant_config,
                         is_use_fused_moe=True)

        self.config = config
        self.rank = self.tp_rank
        self.shared_expert = None
        self.shared_expert_gate = None
        # Shared expert is optional; a size of 0 disables it entirely.
        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = FeedForward(hidden_size=config.hidden_size,
                                             intermediate_size=config.shared_expert_intermediate_size,
                                             hidden_act=config.hidden_act,
                                             up_proj_name='gate_up_proj',
                                             is_gated=config.is_gated,
                                             down_proj_name='down_proj',
                                             bias=config.mlp_bias,
                                             quant_config=quant_config,
                                             reduce_results=False)
            # Produces a single per-token logit gating the shared expert.
            self.shared_expert_gate = ReplicatedLinear(config.hidden_size,
                                                       1,
                                                       bias=False,
                                                       params_dtype=self.params_dtype,
                                                       quant_config=None)

    def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        shared_output = None
        if self.shared_expert is not None:
            shared_output = self.shared_expert(hidden_states)
            if self.shared_expert_gate is not None:
                # ReplicatedLinear returns (output, bias); [0] is the logit.
                gate_output = self.shared_expert_gate(hidden_states)
                shared_output = F.sigmoid(gate_output[0]) * shared_output

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        # Fold the residual in on rank 0 only; otherwise the subsequent
        # all-reduce would sum tp_size copies of it (same convention as the
        # rank-0-only fc2 bias in SparseMoeMlp).
        residual_ = None if self.rank > 0 else residual
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: modify bt_ops.fused_moe to forward_experts
        '''
        # BUGFIX: pass the rank-gated residual_ (the original passed the
        # un-gated `residual`, leaving `residual_` dead and double-counting
        # the residual after the all-reduce on tp_size > 1).
        final_hidden_states = self.forward_experts(hidden_states, router_logits, residual_)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: add comment to explain use_parallel_residual usage
        '''
        # use_parallel_residual = True: x = x + attn(ln1(x)) + mlp(ln2(x))
        # use_parallel_residual = False:
        #   if apply_residual_connection_post_layernorm:
        #       x_attn = ln1(x) + attn(ln1(x))
        #       x_mlp = ln2(x_attn) + mlp(ln2(x_attn))
        #   else:
        #       x_attn = x + attn(ln1(x))
        #       x_mlp = x_attn + mlp(ln2(x_attn))
        # When use_parallel_residual = True, x is shared between attn and mlp, so we only need to
        # reduce after x + attn(ln1(x)) + mlp(ln2(x)) and don't need reduce here
        # But when use_parallel_residual = False, mlp layer uses attn layer's output, so need reduce
        # when mlp is finished.
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        reduce_results = not self.config.use_parallel_residual
        if reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
def vllm__mlu_hijack__model_executor__custom_model__custom__CustomDecoderLayer____init__(
    self,
    config: CustomConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
) -> None:
    """Hijacked CustomDecoderLayer.__init__.

    Identical to the stock constructor except that the MoE path instantiates
    the expert-parallel-aware CustomMoeBlock defined in this file instead of
    the upstream MoE block.
    """
    super(CustomDecoderLayer, self).__init__()
    self.config = config
    self.self_attn = CustomAttention(
        config=config,
        cache_config=cache_config,
        quant_config=quant_config,
    )

    mlp_bias = getattr(config, "mlp_bias", False) or getattr(config, "bias", False)
    is_gated = getattr(config, "is_gated", False)

    # MoE layer when experts are configured, dense FeedForward otherwise.
    if config.num_experts is not None:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: nothing changed, only use the CustomMoeBlock class in this file
        '''
        self.mlp = CustomMoeBlock(config=config,
                                  quant_config=quant_config)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
    else:
        # With parallel residual the MLP bias is deferred (skip_bias_add) and
        # the result is reduced together with the attention output later.
        self.mlp = FeedForward(hidden_size=config.hidden_size,
                               intermediate_size=config.intermediate_size,
                               hidden_act=self.config.hidden_act,
                               up_proj_name='up_proj',
                               is_gated=is_gated,
                               down_proj_name='down_proj',
                               bias=mlp_bias,
                               quant_config=quant_config,
                               skip_bias_add=(self.config.use_parallel_residual and mlp_bias),
                               reduce_results = (self.config.use_parallel_residual == False))

    self.input_layernorm = _NORM_DICT[self.config.norm_type](config.hidden_size, eps=config.norm_eps)
    self.post_attention_layernorm = _NORM_DICT[self.config.norm_type](config.hidden_size, eps=config.norm_eps)

    # perf per-tensor sq cases by fusing quantization in layernorm
    # NOTE(review): "tesnor" is a typo in the attribute name; kept as-is
    # because other code may read it -- confirm before renaming.
    self.is_per_tesnor_sq_perf_cases = (is_per_tensor_smoothquant(quant_config) and
                                        not self.config.apply_residual_connection_post_layernorm)
    self.is_per_token_sq_perf_cases = (is_per_token_smoothquant(quant_config) and
                                       not self.config.apply_residual_connection_post_layernorm)
    if self.is_per_tesnor_sq_perf_cases or self.is_per_token_sq_perf_cases:
        # Quantization is fused into the preceding layernorm, so the linear
        # layers must not quantize their input a second time.
        self.self_attn.qkv_proj.quant_method.skip_quant_input = True
        self.quant_fusion_attn_layernorm = None
        self.is_moe = config.num_experts is not None
        self.use_rmsnorm = self.config.norm_type == "rmsnorm"
        if not self.is_moe:
            self.mlp.up_proj.quant_method.skip_quant_input = True
        self.quant_fusion_mlp_layernorm = None


MluHijackObject.apply_hijack(CustomDecoderLayer,
                             CustomDecoderLayer.__init__,
                             vllm__mlu_hijack__model_executor__custom_model__custom__CustomDecoderLayer____init__)
|
||||
@@ -0,0 +1,222 @@
|
||||
|
||||
import re
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM
|
||||
from vllm_mlu.model_executor.models.deepseek_v2 import DeepseekV2MoE
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size)
|
||||
|
||||
|
||||
def vllm_mlu__model_executor__models__deepseek_v2__DeepseekV2MoE____init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Hijacked DeepseekV2MoE.__init__ built on the MLU SparseMoeMlp.

    Routed experts are managed by SparseMoeMlp (expert/tensor parallel
    aware); shared experts are realized with the MLU FeedForward layer.

    Raises:
        ValueError: if the MoE TP size exceeds the routed-expert count, or
            the activation is not ``silu``.
    """
    super(DeepseekV2MoE, self).__init__(num_experts=config.n_routed_experts,
                                        top_k=config.num_experts_per_tok,
                                        hidden_size=config.hidden_size,
                                        intermediate_size=config.moe_intermediate_size,
                                        up_proj_name="gate_up_proj",
                                        is_gated=True,
                                        down_proj_name="down_proj",
                                        has_bias=False,
                                        skip_bias_add=False,
                                        renormalize=config.norm_topk_prob,
                                        hidden_act=config.hidden_act,
                                        params_dtype=None,
                                        quant_config=quant_config,
                                        is_use_fused_moe=True,
                                        expert_group=config.n_group,
                                        topk_group=config.topk_group)
    self.config = config
    # NOTE: the original assigned routed_scaling_factor twice; the duplicate
    # assignment was removed.
    self.routed_scaling_factor = config.routed_scaling_factor
    self.n_shared_experts = config.n_shared_experts
    if self.moe_tp_size > config.n_routed_experts:
        raise ValueError(
            f"Moe Tensor parallel size {self.moe_tp_size} is greater than "
            f"the number of experts {config.n_routed_experts}.")

    if config.hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                         "Only silu is supported for now.")

    # Router gate is replicated (unsharded) across the TP group.
    self.gate = ReplicatedLinear(config.hidden_size,
                                 config.n_routed_experts,
                                 bias=False,
                                 quant_config=None,
                                 prefix=f"{prefix}.gate")
    if config.n_shared_experts is not None:
        intermediate_size = (config.moe_intermediate_size *
                             config.n_shared_experts)
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: replace MLP with FeedForward.
        '''
        self.shared_experts = FeedForward(hidden_size=config.hidden_size,
                                          intermediate_size=intermediate_size,
                                          hidden_act=config.hidden_act,
                                          up_proj_name='gate_up_proj',
                                          is_gated=True,
                                          down_proj_name='down_proj',
                                          bias=False,
                                          quant_config=quant_config,
                                          reduce_results=False)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
|
||||
|
||||
|
||||
def vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Hijacked DeepseekV2 weight loader with expert-parallel name remapping.

    Packs every SparseMoeMlp's parameters, then loads checkpoint weights,
    renaming global expert indices to this EP rank's local indices and
    skipping experts that belong to other ranks.
    """
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: pack params and cal start expert id
    '''
    for name, m in self.model.named_modules():
        if isinstance(m, SparseMoeMlp):
            m.pack_params()

    # expert parallel modification start
    moe_ep_rank = get_moe_expert_parallel_rank()
    moe_ep_size = get_moe_expert_parallel_world_size()
    num_total_experts = self.config.n_routed_experts
    # First global expert id owned by this EP rank (contiguous partition).
    start_expert_id = moe_ep_rank * ((num_total_experts + moe_ep_size - 1) // moe_ep_size)
    # expert parallel modification end
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: delete expert_params_mapping for no useless
    '''
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: replace expert_id in weight to named_expert_id in params_dict
        '''
        # Translate the checkpoint's global expert index into this rank's
        # local index (rank 0 needs no translation: start_expert_id == 0).
        if start_expert_id > 0 and "mlp.experts." in name:
            expert_str = re.search(r'experts\.\d+', name).group(0)
            expert_id=int(expert_str.split(".")[1])
            named_expert_id = expert_id - start_expert_id
            old_expert_name = f"experts.{expert_id}"
            new_expert_name = f"experts.{named_expert_id}"
            name = name.replace(old_expert_name, new_expert_name)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: add expert skiped condition and delete useless if name not in params_dict: continue condition
            '''
            name = name.replace(weight_name, param_name)
            # Experts not owned by this EP rank (and shared-expert params not
            # present locally) are silently skipped.
            if (("mlp.experts." in name or "mlp.shared_experts." in name or "mlp.shared_expert_gate." in name)
                    and name not in params_dict):
                continue
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            # NOTE(review): any other name missing from params_dict raises
            # KeyError here by design (fail fast on unexpected weights).
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: add expert skiped condition
            '''
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if (("mlp.experts." in name or "mlp.shared_experts." in name or "mlp.shared_expert_gate." in name)
                    and name not in params_dict):
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            '''
            ==================
            End of MLU Hijack
            ==================
            '''


MluHijackObject.apply_hijack(DeepseekV2MoE,
                             DeepseekV2MoE.__init__,
                             vllm_mlu__model_executor__models__deepseek_v2__DeepseekV2MoE____init__)
MluHijackObject.apply_hijack(DeepseekV2ForCausalLM,
                             DeepseekV2ForCausalLM.load_weights,
                             vllm__module_executor__models__deepseek_v2__DeepseekV2ForCausalLM__load_weights)
|
||||
@@ -0,0 +1,143 @@
|
||||
import torch
|
||||
import re
|
||||
import vllm
|
||||
from torch import nn
|
||||
from typing import List, Optional, Tuple, Iterable
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.models.mixtral import MixtralForCausalLM
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader, maybe_remap_kv_scale_name
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size)
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
|
||||
|
||||
def vllm__module_executor__models__mixtral__MixtralForCausalLM__load_weights(
    self,
    weights: Iterable[Tuple[str, torch.Tensor]]):
    """Hijacked Mixtral weight loader with expert-parallel name remapping.

    Packs every SparseMoeMlp's parameters, then loads checkpoint weights,
    renaming global expert indices to this EP rank's local indices and
    skipping experts owned by other ranks.
    """
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: pack params and cal start expert id
    '''
    for name, m in self.model.named_modules():
        if isinstance(m, SparseMoeMlp):
            m.pack_params()
    # expert parallel modification start
    moe_ep_rank = get_moe_expert_parallel_rank()
    moe_ep_size = get_moe_expert_parallel_world_size()
    num_total_experts = self.config.num_local_experts
    # First global expert id owned by this EP rank (contiguous partition).
    start_expert_id = moe_ep_rank * ((num_total_experts + moe_ep_size - 1) // moe_ep_size)
    # expert parallel modification end
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("w13", "w1", 0),
        ("w13", "w3", 1),
    ]

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: delete expert_params_mapping for no useless
    '''
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: replace expert_id in weight to named_expert_id in params_dict
        '''
        # Translate the checkpoint's global expert index into this rank's
        # local index (rank 0 needs no translation: start_expert_id == 0).
        if start_expert_id > 0 and "block_sparse_moe.experts." in name:
            expert_str = re.search(r'experts\.\d+', name).group(0)
            expert_id=int(expert_str.split(".")[1])
            named_expert_id = expert_id - start_expert_id
            old_expert_name = f"experts.{expert_id}"
            new_expert_name = f"experts.{named_expert_id}"
            name = name.replace(old_expert_name, new_expert_name)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: add expert skiped condition
            '''
            # Skip experts that are not assigned to this worker.
            if (("block_sparse_moe.experts." in name) and (name not in params_dict)):
                continue
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: add expert skiped condition
            '''
            # Skip experts that are not assigned to this worker.
            if (("block_sparse_moe.experts." in name) and (name not in params_dict)):
                continue
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)


MluHijackObject.apply_hijack(MixtralForCausalLM,
                             MixtralForCausalLM.load_weights,
                             vllm__module_executor__models__mixtral__MixtralForCausalLM__load_weights)
|
||||
@@ -0,0 +1,179 @@
|
||||
import torch
|
||||
import re
|
||||
from typing import Optional, Iterable, Tuple
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size)
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm.utils import print_warning_once
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
|
||||
|
||||
def vllm__module_executor__models__qwen2moe__Qwen2MoeForCausalLM__load_weights(
    self,
    weights: Iterable[Tuple[str, torch.Tensor]]):
    """Hijacked Qwen2-MoE weight loader with expert-parallel name remapping.

    Packs every SparseMoeMlp's parameters, then loads checkpoint weights,
    renaming global expert indices to this EP rank's local indices and
    skipping expert / shared-expert weights not present on this rank.
    """
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: pack params and cal start expert id
    '''
    for name, m in self.model.named_modules():
        if isinstance(m, SparseMoeMlp):
            m.pack_params()

    # expert parallel modification start
    moe_ep_rank = get_moe_expert_parallel_rank()
    moe_ep_size = get_moe_expert_parallel_world_size()
    num_total_experts = self.config.num_experts
    # First global expert id owned by this EP rank (contiguous partition).
    start_expert_id = moe_ep_rank * ((num_total_experts + moe_ep_size - 1) // moe_ep_size)
    # expert parallel modification end
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: delete expert_params_mapping for no useless
    '''
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: replace expert_id in weight to named_expert_id in params_dict
        '''
        # Translate the checkpoint's global expert index into this rank's
        # local index (rank 0 needs no translation: start_expert_id == 0).
        if start_expert_id > 0 and "mlp.experts." in name:
            expert_str = re.search(r'experts\.\d+', name).group(0)
            expert_id=int(expert_str.split(".")[1])
            named_expert_id = expert_id - start_expert_id
            old_expert_name = f"experts.{expert_id}"
            new_expert_name = f"experts.{named_expert_id}"
            name = name.replace(old_expert_name, new_expert_name)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: delete if "mlp.experts" in name: continue condition
            '''
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: add expert skiped condition and delete useless if name not in params_dict: continue condition
            '''
            # Skip experts that are not assigned to this worker.
            if (("mlp.experts." in name or "mlp.shared_expert." in name or "mlp.shared_expert_gate." in name)
                    and name not in params_dict):
                continue
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: delete for mapping in expert_params_mapping condition
            '''
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
            # Skip loading extra bias for GPTQ models.
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            # Remapping the name of FP8 kv-scale.
            if name.endswith("kv_scale"):
                remapped_kv_scale_name = name.replace(
                    ".kv_scale", ".attn.kv_scale")
                if remapped_kv_scale_name not in params_dict:
                    print_warning_once(
                        "Found kv scale in the checkpoint "
                        f"(e.g. {name}), but not found the expected "
                        f"name in the model "
                        f"(e.g. {remapped_kv_scale_name}). "
                        "kv-scale is not loaded.")
                    continue
                else:
                    name = remapped_kv_scale_name
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: add expert skiped condition
            '''
            # Skip experts that are not assigned to this worker.
            if (("mlp.experts." in name or "mlp.shared_expert." in name or "mlp.shared_expert_gate." in name)
                    and name not in params_dict):
                continue
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)


MluHijackObject.apply_hijack(Qwen2MoeForCausalLM,
                             Qwen2MoeForCausalLM.load_weights,
                             vllm__module_executor__models__qwen2moe__Qwen2MoeForCausalLM__load_weights)
|
||||
@@ -0,0 +1,61 @@
|
||||
import os
# Must be set before importing vLLM so the hijacked parallel_state module
# enables expert parallelism at import time.
os.environ['EXPERT_PARALLEL_EN'] = "True"

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


# Demo configuration: 2-way tensor parallel with 2-way MoE expert parallel.
model_dir="/data/AE/llm/models/Qwen1.5-MoE-A2.7B"
tp_size = 2
moe_ep_size=2
# When True, install smooth-quant activation hooks and dump the collected
# activation ranges before generation.
is_check_act_range = True
input_seq_len=64
output_seq_len=1
batch=1
# max_position_embedding=1024
max_model_len=input_seq_len + output_seq_len
# if max_model_len < max_position_embedding:
#     max_model_len = max_position_embedding
max_num_batched_tokens=input_seq_len * batch
# The scheduler requires max_num_batched_tokens >= max_model_len.
if max_model_len > max_num_batched_tokens:
    max_num_batched_tokens=max_model_len
max_num_seqs = batch

if __name__ == '__main__':
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8)

    # Create an LLM.
    llm = LLM(model=model_dir,
              trust_remote_code=True,
              enforce_eager=True,
              dtype='bfloat16',
              max_model_len=max_model_len,
              max_num_batched_tokens=max_num_batched_tokens,
              max_num_seqs=max_num_seqs,
              tensor_parallel_size=tp_size,
              moe_ep_size=moe_ep_size,
              )

    if is_check_act_range:
        # NOTE(review): hooks are installed and removed before generate(),
        # so act_range reflects warm-up activity only -- confirm intended.
        llm.llm_engine.model_executor._run_workers("setup_smooth_hook", is_save_moe_info=True)

        llm.llm_engine.model_executor._run_workers("remove_hooks")
        act_range = llm.llm_engine.model_executor._run_workers("get_act_range")
        print(f"len(act_range)={len(act_range)}")

    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
# Launch a vLLM OpenAI-compatible API server for a smoothquant DeepSeek-V2
# model with 8-way tensor or pipeline parallelism.
# FIX: the original shebang was `#/bin/bash` (a plain comment), so the script
# ran under whatever shell happened to invoke it.

rm output/server -rf
mkdir -p output/server

PORT=32345
use_ray=0
use_pp=1
use_eager=0

# Optional flags assembled from the toggles above.
eager_option=""
if [ $use_eager -gt 0 ]; then
    eager_option="--enforce-eager"
fi

ray_option=""
if [ $use_ray -gt 0 ]; then
    ray_option="--worker-use-ray"
    # Make sure no stale ray cluster is left over from a previous run.
    ray stop --force
fi

export VLLM_ENGINE_ITERATION_TIMEOUT_S=180
MODEL_PATH="/data/vllm/sq_per_token_per_channel/deepseek_v2_temp"

# Choose pipeline vs. tensor parallelism (8 ways either way).
if [ $use_pp -gt 0 ]; then
    parallel_option="--pipeline-parallel-size=8"
else
    parallel_option="--tensor-parallel-size=8"
fi

# TP8
python -m vllm.entrypoints.openai.api_server \
    --disable-log-requests \
    --port ${PORT} \
    --model ${MODEL_PATH} \
    --trust-remote-code \
    --swap-space 16 \
    ${parallel_option} \
    --max-num-batched-tokens=40960 \
    --max-model-len=1034 \
    --block-size=16 \
    --dtype=bfloat16 \
    --max-seq-len-to-capture=1034 \
    --max-num-seqs=40 \
    --quantization=smoothquant \
    ${eager_option} \
    ${ray_option} \
    2>&1 | tee output/server/server.log
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import sys
|
||||
import ray
|
||||
import gc
|
||||
import contextlib
|
||||
import os
|
||||
os.environ['CONTEXT_PARALLEL_EN'] = "True"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
def cleanup():
    """Tear down parallel groups, the torch process group, caches and ray."""
    from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import destroy_model_parallel
    from vllm.distributed import destroy_distributed_environment

    destroy_model_parallel()
    destroy_distributed_environment()
    # destroy_process_group asserts when no group was ever initialized.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    # Device-side caches only exist off-CPU.
    if not current_platform.is_cpu():
        torch.cuda.empty_cache()

    if ray.is_initialized():
        ray.shutdown()
||||
def run_vllm(prompts, sampling_params, tp, cp):
    """Build a Llama-2 engine with the given TP/CP sizes and generate."""
    engine = LLM(model="/data/AE/llm/models/Llama-2-7b-hf/",
                 enforce_eager=True,
                 tensor_parallel_size=tp,
                 context_parallel_size=cp,
                 distributed_executor_backend='ray')
    return engine.generate(prompts, sampling_params)
||||
def test_context_parallel():
    """cp=2 must generate exactly the same text as cp=1."""
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    params = SamplingParams(temperature=0.8, max_tokens=16)
    cp2_outputs = run_vllm(prompts, params, tp=1, cp=2)
    cleanup()
    cp1_outputs = run_vllm(prompts, params, tp=1, cp=1)
    cleanup()
    cp2_texts = [item.outputs[0].text for item in cp2_outputs]
    cp1_texts = [item.outputs[0].text for item in cp1_outputs]
    assert cp2_texts == cp1_texts
||||
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
import sys
|
||||
import ray
|
||||
import gc
|
||||
import contextlib
|
||||
import os
|
||||
os.environ['CONTEXT_PARALLEL_EN'] = "True"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
def cleanup():
    """Tear down parallel groups, the torch process group, caches and ray."""
    from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import destroy_model_parallel
    from vllm.distributed import destroy_distributed_environment

    destroy_model_parallel()
    destroy_distributed_environment()
    # destroy_process_group asserts when no group was ever initialized.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    # Device-side caches only exist off-CPU.
    if not current_platform.is_cpu():
        torch.cuda.empty_cache()

    if ray.is_initialized():
        ray.shutdown()
|
||||
def run_vllm(prompts, sampling_params, tp, cp, use_kv8=False):
    """Run LLM with the given TP/CP sizes and int8 kv-cache, then generate.

    NOTE(review): ``use_kv8`` is accepted but never consulted — the kv-cache
    dtype is unconditionally 'int8'. Kept as-is to preserve behavior; confirm
    whether the flag was meant to gate ``kv_cache_dtype``.
    """
    kwargs = dict()
    kwargs['model'] = "/data/AE/llm/models/Llama-2-7b-hf/"
    # FIX: the original line ended with a stray comma
    # (kwargs['enforce_eager']=True,), which stored the tuple (True,) instead
    # of the boolean True.
    kwargs['enforce_eager'] = True
    kwargs['tensor_parallel_size'] = tp
    kwargs['context_parallel_size'] = cp
    kwargs['distributed_executor_backend'] = 'ray'
    kwargs['kv_cache_dtype'] = 'int8'

    llm = LLM(**kwargs)
    outputs = llm.generate(prompts, sampling_params)
    return outputs
||||
def test_context_parallel_with_kv8():
|
||||
"""Compare the output results of cp1 and cp2 with kv cache int8."""
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, max_tokens=16)
|
||||
outputs_1 = run_vllm(prompts, sampling_params, tp=1, cp=2)
|
||||
cleanup()
|
||||
@@ -0,0 +1,76 @@
|
||||
import torch
|
||||
import sys
|
||||
import ray
|
||||
import gc
|
||||
import contextlib
|
||||
import numpy as np
|
||||
import os
|
||||
os.environ['EXPERT_PARALLEL_EN'] = "True"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
def string_list_to_float(text_list: list):
    """Encode a list of strings as a float32 matrix of character codes.

    Each string is right-padded with spaces to the longest length, then every
    character is replaced by its Unicode code point.
    """
    width = max(len(item) for item in text_list)
    rows = []
    for item in text_list:
        padded = item.ljust(width)
        rows.append([ord(ch) for ch in padded])
    return np.array(rows).astype('float32')
|
||||
def compute_diff_text(baseline_text: list, compare_text: list):
    """Return (diff1, diff2): relative L1 and relative L2 error between the
    character encodings of two text lists."""
    reference = string_list_to_float(baseline_text)
    candidate = string_list_to_float(compare_text)
    abs_err = np.abs(reference - candidate)
    diff1 = np.sum(abs_err) / np.sum(np.abs(reference))
    diff2 = np.sqrt(np.sum(abs_err ** 2) / np.sum(reference ** 2))
    return diff1, diff2
|
||||
def cleanup():
    """Tear down parallel groups, the torch process group, caches and ray."""
    from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import destroy_model_parallel
    from vllm.distributed import destroy_distributed_environment

    destroy_model_parallel()
    destroy_distributed_environment()
    # destroy_process_group asserts when no group was ever initialized.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    # Device-side caches only exist off-CPU.
    if not current_platform.is_cpu():
        torch.cuda.empty_cache()

    if ray.is_initialized():
        ray.shutdown()
|
||||
def run_vllm(prompts, sampling_params, tp, mtp=-1, mep=-1, model_dir="/data/AE/llm/models/Qwen1.5-MoE-A2.7B/"):
    """Build an MoE-enabled engine with the given parallel sizes and generate."""
    engine = LLM(model=model_dir,
                 enforce_eager=True,
                 tensor_parallel_size=tp,
                 moe_tp_size=mtp,
                 moe_ep_size=mep)
    return engine.generate(prompts, sampling_params)
|
||||
def test_expert_parallel():
    """At tp=2, moe_tp=1 and moe_tp=2 must produce numerically identical text."""
    qwen2_moe_model_dir = "/data/AE/llm/models/Qwen1.5-MoE-A2.7B"
    eps = 1e-6
    prompts = [
        "Hello, my name is",
    ]
    params = SamplingParams(temperature=0.8, max_tokens=1)
    first_run = run_vllm(prompts, params, tp=2, mtp=1, model_dir=qwen2_moe_model_dir)
    cleanup()
    second_run = run_vllm(prompts, params, tp=2, mtp=2, model_dir=qwen2_moe_model_dir)
    cleanup()
    texts_a = [item.outputs[0].text for item in first_run]
    texts_b = [item.outputs[0].text for item in second_run]
    diff1, diff2 = compute_diff_text(texts_a, texts_b)
    assert diff1 <= eps and diff2 <= eps, (
        f"qwen2_moe generated_1({texts_a}) and generated_2{texts_b} diff error")
||||
@@ -0,0 +1,17 @@
|
||||
import logging
|
||||
from logging import Logger
|
||||
|
||||
def init_logger(name: str) -> Logger:
    """Initialize loggers for benchmarks module,
    and keep the configuration consistent with the vllm module.

    If a real 'vllm' logger exists, its level, propagation flag and handlers
    are copied onto the new logger; otherwise the logger is returned with
    default configuration.
    """
    logger = logging.getLogger(name)

    vllm_logger = logging.Logger.manager.loggerDict.get('vllm', None)
    # FIX: loggerDict may hold a logging.PlaceHolder (created when only
    # 'vllm.*' child loggers exist) which has no level/propagate/handlers —
    # only copy configuration from a real Logger instance.
    if isinstance(vllm_logger, logging.Logger):
        logger.setLevel(vllm_logger.level)
        logger.propagate = vllm_logger.propagate
        logger.handlers = vllm_logger.handlers

    return logger
|
||||
@@ -0,0 +1,110 @@
|
||||
import torch
|
||||
from vllm.config import ParallelConfig, TokenizerPoolConfig
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
from vllm.platforms import current_platform
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__config__ParallelConfig___init__(
    self,
    pipeline_parallel_size: int,
    tensor_parallel_size: int,
    worker_use_ray: Optional[bool] = None,
    max_parallel_loading_workers: Optional[int] = None,
    disable_custom_all_reduce: bool = False,
    tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
    ray_workers_use_nsight: bool = False,
    placement_group: Optional["PlacementGroup"] = None,
    distributed_executor_backend: Optional[Union[
        str, Type["ExecutorBase"]]] = None,
) -> None:
    """Replacement for vllm.config.ParallelConfig.__init__, installed via
    MluHijackObject below; extends world_size with context parallelism and
    records MoE parallel sizes."""
    self.pipeline_parallel_size = pipeline_parallel_size
    self.tensor_parallel_size = tensor_parallel_size
    self.distributed_executor_backend = distributed_executor_backend
    self.max_parallel_loading_workers = max_parallel_loading_workers
    self.disable_custom_all_reduce = disable_custom_all_reduce
    self.tokenizer_pool_config = tokenizer_pool_config
    self.ray_workers_use_nsight = ray_workers_use_nsight
    self.placement_group = placement_group

    '''
    ==========================
    Modify by vllm_mlu
    ==========================
    @brief: modify world_size
    '''
    # These self-assignments copy *class* attributes onto the instance; the
    # class attributes are attached by the hijacked
    # EngineArgs.create_engine_config before this __init__ runs.
    # NOTE(review): if that hijack has not run, these raise AttributeError —
    # confirm the initialization order is always guaranteed.
    self.context_parallel_size = self.context_parallel_size
    self.moe_tp_size = self.moe_tp_size
    self.moe_ep_size = self.moe_ep_size

    # World size gains a context-parallel factor on top of stock vLLM.
    self.world_size = pipeline_parallel_size * tensor_parallel_size * self.context_parallel_size
    '''
    =======================
    End of MLU Hijack
    =======================
    '''
    if worker_use_ray:
        if self.distributed_executor_backend is None:
            self.distributed_executor_backend = "ray"
        # NOTE(review): assumes stock ParallelConfig defines a `use_ray`
        # property — not visible here; confirm.
        elif not self.use_ray:
            raise ValueError(f"worker-use-ray can't be used with "
                             f"distributed executor backend "
                             f"'{self.distributed_executor_backend}'.")

    # TPU multi-device inference requires the ray backend.
    if current_platform.is_tpu() and self.world_size > 1:
        if self.distributed_executor_backend is None:
            self.distributed_executor_backend = "ray"
        if self.distributed_executor_backend != "ray":
            raise ValueError(
                "TPU backend only supports Ray for distributed inference.")

    # HPU multi-device inference requires the ray backend.
    if current_platform.is_hpu() and self.world_size > 1:
        if self.distributed_executor_backend is None:
            self.distributed_executor_backend = "ray"
        if self.distributed_executor_backend != "ray":
            raise ValueError(
                "HPU backend only supports Ray for distributed inference.")

    if self.distributed_executor_backend is None and self.world_size > 1:
        # We use multiprocessing by default if world_size fits on the
        # current node and we aren't in a ray placement group.

        from vllm.executor import ray_utils
        backend = "mp"
        ray_found = ray_utils.ray_is_available()
        if (current_platform.is_cuda()
                and cuda_device_count_stateless() < self.world_size):
            # Not enough local devices: multi-node inference needs ray.
            if not ray_found:
                raise ValueError("Unable to load Ray which is "
                                 "required for multi-node inference, "
                                 "please install Ray with `pip install "
                                 "ray`.") from ray_utils.ray_import_err
            backend = "ray"
        elif ray_found:
            # Prefer ray when we are already inside a placement group.
            if self.placement_group:
                backend = "ray"
            else:
                from ray import is_initialized as ray_is_initialized
                if ray_is_initialized():
                    from ray.util import get_current_placement_group
                    if get_current_placement_group():
                        backend = "ray"
        self.distributed_executor_backend = backend
        logger.info("Defaulting to use %s for distributed inference",
                    backend)

    self._verify_args()
    # Rank is filled in later by the distributed initialization.
    self.rank: int = 0


# Install the hijacked __init__ onto vllm's ParallelConfig.
MluHijackObject.apply_hijack(ParallelConfig,
                             ParallelConfig.__init__,
                             vllm__config__ParallelConfig___init__)
||||
@@ -0,0 +1,2 @@
|
||||
from . import communication_op
|
||||
from . import parallel_state
|
||||
@@ -0,0 +1,21 @@
|
||||
import torch
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from .parallel_state import get_tp_group
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor, tp_group: Any = None) -> torch.Tensor:
    """All-reduce *input_* across the (optionally overridden) TP group."""
    group = get_tp_group(tp_group)
    return group.all_reduce(input_)
||||
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1, tp_group: Any = None) -> torch.Tensor:
    """All-gather *input_* along *dim* across the (optionally overridden) TP group."""
    group = get_tp_group(tp_group)
    return group.all_gather(input_, dim)
||||
def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1, tp_group: Any = None) -> Optional[torch.Tensor]:
    """Gather *input_* to rank *dst* along *dim* across the (optionally
    overridden) TP group; non-destination ranks receive None."""
    group = get_tp_group(tp_group)
    return group.gather(input_, dst, dim)
||||
@@ -0,0 +1,339 @@
|
||||
import torch
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed.parallel_state import (init_model_parallel_group, get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank, get_world_group, get_pp_group,
|
||||
GroupCoordinator)
|
||||
import vllm.distributed.parallel_state as parallel_state_org
|
||||
from vllm.distributed.parallel_state import model_parallel_is_initialized as model_parallel_is_initialized_org
|
||||
from vllm.distributed.parallel_state import destroy_model_parallel as destroy_model_parallel_org
|
||||
|
||||
def get_tp_group(tp_group: Any = None) -> GroupCoordinator:
    """Return *tp_group* when given, otherwise the global TP coordinator."""
    if tp_group is None:
        assert parallel_state_org._TP is not None, ("tensor model parallel group is not initialized")
        return parallel_state_org._TP
    return tp_group
|
||||
# Process-group coordinators owned by this hijack module. Each is populated
# by initialize_model_parallel() and reset by destroy_model_parallel().
_CP: Optional[GroupCoordinator] = None


def get_cp_group() -> GroupCoordinator:
    """Return the context-parallel coordinator; asserts it is initialized."""
    assert _CP is not None, ("context parallel group is not initialized")
    return _CP


# kept for backward compatibility
get_context_model_parallel_group = get_cp_group


_MOE_TP: Optional[GroupCoordinator] = None


def get_moe_tp_group() -> GroupCoordinator:
    """Return the MoE tensor-parallel coordinator; asserts it is initialized."""
    assert _MOE_TP is not None, ("moe tensor parallel group is not initialized")
    return _MOE_TP


# kept for backward compatibility
get_moe_tensor_parallel_group = get_moe_tp_group


_MOE_EP: Optional[GroupCoordinator] = None


def get_moe_ep_group() -> GroupCoordinator:
    """Return the MoE expert-parallel coordinator; asserts it is initialized."""
    assert _MOE_EP is not None, ("moe expert parallel group is not initialized")
    return _MOE_EP


# kept for backward compatibility
get_moe_expert_parallel_group = get_moe_ep_group
||||
def initialize_model_parallel(
    parallel_config: ParallelConfig,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        parallel_config: carries the tensor / pipeline / context / moe
            parallel sizes (stock vLLM passes the sizes individually; this
            hijack takes the whole config).
        backend: torch.distributed backend; defaults to the world group's.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
    4 tensor model-parallel groups:
        [g0, g1], [g2, g3], [g4, g5], [g6, g7]
    2 pipeline model-parallel groups:
        [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # --- Modify by vllm_mlu: read parallel sizes from parallel_config and
    # validate them against world_size. ---
    tensor_model_parallel_size = parallel_config.tensor_parallel_size
    pipeline_model_parallel_size = parallel_config.pipeline_parallel_size
    context_model_parallel_size = parallel_config.context_parallel_size
    moe_tensor_parallel_size = parallel_config.moe_tp_size
    moe_expert_parallel_size = parallel_config.moe_ep_size

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size):
        # FIX: the original message lacked a space between "x" and
        # "context_model_parallel_size".
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size}) x "
            f"context_model_parallel_size ({context_model_parallel_size})")

    if (moe_tensor_parallel_size < 1 or moe_expert_parallel_size < 1 or tensor_model_parallel_size !=
            moe_tensor_parallel_size * moe_expert_parallel_size):
        # FIX: the original message interpolated world_size where the
        # tensor-parallel size belongs.
        raise RuntimeError(
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) is not equal to "
            f"moe_tensor_parallel_size ({moe_tensor_parallel_size}) x "
            f"moe_expert_parallel_size ({moe_expert_parallel_size})")

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    assert parallel_state_org._TP is None, ("tensor model parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    parallel_state_org._TP = init_model_parallel_group(group_ranks,
                                                       get_world_group().local_rank,
                                                       backend,
                                                       use_message_queue_broadcaster=True,
                                                       group_name="tp")

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    assert parallel_state_org._PP is None, (
        "pipeline model parallel group is already initialized")
    group_ranks = []
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        group_ranks.append(ranks)
    # pipeline parallel does not need custom allreduce
    parallel_state_org._PP = init_model_parallel_group(group_ranks,
                                                       get_world_group().local_rank,
                                                       backend,
                                                       use_custom_allreduce=False,
                                                       group_name="pp")

    # --- Modify by vllm_mlu: additionally build _CP, _MOE_TP, _MOE_EP. ---
    # Build the context parallel groups.
    num_context_model_parallel_groups: int = (world_size //
                                              context_model_parallel_size)
    global _CP
    assert _CP is None, (
        "context parallel group is already initialized")
    group_ranks = []
    # NOTE(review): CP ranks are strided by the TP size starting at rank i;
    # this layout does not account for pipeline stages — confirm for
    # pipeline_parallel_size > 1.
    for i in range(num_context_model_parallel_groups):
        ranks = list(range(i, context_model_parallel_size * tensor_model_parallel_size + i, tensor_model_parallel_size))
        group_ranks.append(ranks)
    # message queue broadcaster is set to be used in context parallel group
    _CP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="cp")

    # Build the moe tensor parallel groups: within each TP group, ranks that
    # share an expert shard (stride = moe_expert_parallel_size).
    global _MOE_TP
    assert _MOE_TP is None, ("moe tensor parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        for j in range(moe_expert_parallel_size):
            ranks = list(range(i * tensor_model_parallel_size + j, (i + 1) * tensor_model_parallel_size,
                               moe_expert_parallel_size))
            group_ranks.append(ranks)

    # message queue broadcaster is set to be used in moe tensor parallel group
    _MOE_TP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="moe_tp")

    # Build the moe expert parallel groups: contiguous runs of
    # moe_expert_parallel_size ranks inside each TP group.
    global _MOE_EP
    assert _MOE_EP is None, ("moe expert parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        for j in range(moe_tensor_parallel_size):
            # FIX: materialize the range as a list for consistency with every
            # other group built here (the others all pass List[int]).
            ranks = list(range(i * tensor_model_parallel_size + j * moe_expert_parallel_size,
                               i * tensor_model_parallel_size + (j + 1) * moe_expert_parallel_size))
            group_ranks.append(ranks)

    # message queue broadcaster is set to be used in moe expert parallel group
    _MOE_EP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="moe_ep")
|
||||
|
||||
def ensure_model_parallel_initialized(
    parallel_config: ParallelConfig,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor/pipeline/context/moe parallel sizes are equal to expected
    values if the model parallel groups are initialized.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    if not model_parallel_is_initialized():
        # Modify by vllm_mlu: the sizes travel inside parallel_config.
        initialize_model_parallel(parallel_config, backend)
        return

    # FIX: ParallelConfig stores `tensor_parallel_size` and
    # `pipeline_parallel_size` (see the hijacked __init__); the original read
    # non-existent `tensor_model_parallel_size` /
    # `pipeline_model_parallel_size` attributes, raising AttributeError on
    # this already-initialized path.
    assert (
        get_tensor_model_parallel_world_size() == parallel_config.tensor_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{parallel_config.tensor_parallel_size=}")
    pp_world_size = get_pp_group().world_size
    assert (pp_world_size == parallel_config.pipeline_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{pp_world_size=} vs. "
        f"{parallel_config.pipeline_parallel_size=}")
    cp_world_size = get_cp_group().world_size
    assert (cp_world_size == parallel_config.context_parallel_size), (
        "context parallel group already initialized, but of unexpected size: "
        f"{cp_world_size=} vs. "
        f"{parallel_config.context_parallel_size=}")
    moe_tp_world_size = get_moe_tp_group().world_size
    assert (moe_tp_world_size == parallel_config.moe_tp_size), (
        "moe tensor parallel group already initialized, but of unexpected size: "
        f"{moe_tp_world_size=} vs. "
        f"{parallel_config.moe_tp_size=}")
    moe_ep_world_size = get_moe_ep_group().world_size
    assert (moe_ep_world_size == parallel_config.moe_ep_size), (
        "moe expert parallel group already initialized, but of unexpected size: "
        f"{moe_ep_world_size=} vs. "
        f"{parallel_config.moe_ep_size=}")
|
||||
def model_parallel_is_initialized():
    """Check if tensor, pipeline, context, moe parallel groups are initialized."""
    # FIX: the original tested the *function object*
    # `model_parallel_is_initialized_org` (always truthy) instead of calling
    # it, so the stock TP/PP state was never actually checked; it also
    # duplicated each `X is not None` test.
    return (model_parallel_is_initialized_org()
            and _CP is not None
            and _MOE_TP is not None
            and _MOE_EP is not None)
||||
|
||||
def destroy_model_parallel():
    """Destroy the stock vLLM groups, then this module's extra groups,
    resetting each handle to None."""
    global _CP, _MOE_TP, _MOE_EP

    destroy_model_parallel_org()

    if _CP:
        _CP.destroy()
        _CP = None

    if _MOE_TP:
        _MOE_TP.destroy()
        _MOE_TP = None

    if _MOE_EP:
        _MOE_EP.destroy()
        _MOE_EP = None
||||
|
||||
# Convenience accessors mirroring vLLM's stock get_*_world_size / get_*_rank
# helpers for the hijack-owned groups.
def get_context_model_parallel_world_size():
    """Return world size for the context parallel group."""
    return get_cp_group().world_size


def get_context_model_parallel_rank():
    """Return my rank for the context parallel group."""
    return get_cp_group().rank_in_group


def get_moe_tensor_parallel_world_size():
    """Return world size for the moe tensor parallel group."""
    return get_moe_tp_group().world_size


def get_moe_tensor_parallel_rank():
    """Return my rank for the moe tensor parallel group."""
    return get_moe_tp_group().rank_in_group


def get_moe_expert_parallel_world_size():
    """Return world size for the moe expert parallel group."""
    return get_moe_ep_group().world_size


def get_moe_expert_parallel_rank():
    """Return my rank for the moe expert parallel group."""
    return get_moe_ep_group().rank_in_group
||||
def get_parallel_world_size_with_group(group):
    """Return *group*'s world size, falling back to the TP group when None."""
    if group is None:
        return get_tensor_model_parallel_world_size()
    return group.world_size
||||
|
||||
def get_parallel_rank_with_group(group):
    """Return my rank within *group*, falling back to the TP rank when None."""
    if group is None:
        return get_tensor_model_parallel_rank()
    return group.rank_in_group
||||
@@ -0,0 +1 @@
|
||||
from . import arg_utils
|
||||
@@ -0,0 +1,141 @@
|
||||
import argparse
|
||||
import torch
|
||||
from vllm.config import VllmConfig, ParallelConfig
|
||||
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Keep references to the stock vLLM implementations so the hijacked versions
# below can delegate to them after apply_hijack swaps the class attributes.
vllm__engine__arg_utils__EngineArgs__create_engine_config_org = EngineArgs.create_engine_config
vllm__engine__arg_utils__EngineArgs__add_cli_args_org = EngineArgs.add_cli_args
vllm__engine__arg_utils__EngineArgs__from_cli_args_org = EngineArgs.from_cli_args
vllm__engine__arg_utils__AsyncEngineArgs__from_cli_args_org = AsyncEngineArgs.from_cli_args
|
||||
|
||||
def vllm__engine__arg_utils__EngineArgs__create_engine_config(self, ) -> VllmConfig:
    """Hijacked EngineArgs.create_engine_config: applies MLU-specific
    constraints, then delegates to the stock implementation.

    CHUNKED_PIPELINE_PARALLEL_EN / USE_PAGED / CONTEXT_PARALLEL_EN /
    EXPERT_PARALLEL_EN / BlockSizeInfo / get_device_major_capability come
    from the star import of vllm_mlu._mlu_utils.
    """
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: chunked parallel pipeline only support batch size = 1 yet.
    '''
    if CHUNKED_PIPELINE_PARALLEL_EN:
        self.max_num_seqs = 1
        logger.info("Reset max_num_seqs to 1 as the chunked parallel pipeline mode "
                    "only supports batch size to 1.")
    '''
    @brief: disable custom_all_reduce, re-set block_size to support paged and unpaged mode.
    '''
    # MLU not support custom all reduce
    self.disable_custom_all_reduce = True
    BlockSizeInfo.set_block_size(self.block_size)
    if not USE_PAGED and self.enable_chunked_prefill:
        raise ValueError("Not support chunked_prefill in unpaged mode.")

    # set parallel_config context_parallel_size, moe_tp_size, moe_ep_size;
    # defaults cover the case where the CLI hijack did not attach the attrs.
    self.context_parallel_size = getattr(self, "context_parallel_size", 1)
    self.moe_tp_size = getattr(self, "moe_tp_size", -1)
    self.moe_ep_size = getattr(self, "moe_ep_size", -1)
    # check context parallel whether supported or not
    if CONTEXT_PARALLEL_EN:
        # NOTE(review): major capability 3 is assumed to identify MLU370 —
        # confirm against _mlu_utils.
        if self.context_parallel_size > 1 and get_device_major_capability() == 3:
            raise ValueError('Context parallel does not support MLU370.')
    else:
        if self.context_parallel_size > 1:
            raise ValueError('Context parallel does not support when CONTEXT_PARALLEL_EN=False')
    # check expert parallel whether supported or not
    if not EXPERT_PARALLEL_EN and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
        raise ValueError('Expert parallel does not support when EXPERT_PARALLEL_EN=False')

    # NOTE: this mutates the ParallelConfig *class*; the hijacked
    # ParallelConfig.__init__ later copies the value onto each instance.
    ParallelConfig.context_parallel_size = self.context_parallel_size

    # set parallel_config moe_tp_size and moe_ep_size;
    # -1 means "derive from tensor_parallel_size and the other moe size".
    if self.moe_tp_size < 1 and self.moe_ep_size < 1:
        moe_tp_size = self.tensor_parallel_size
        moe_ep_size = 1
    elif self.moe_tp_size >= 1 and self.moe_ep_size < 1:
        moe_tp_size = self.moe_tp_size
        moe_ep_size = self.tensor_parallel_size // self.moe_tp_size
    elif self.moe_tp_size < 1 and self.moe_ep_size >= 1:
        moe_tp_size = self.tensor_parallel_size // self.moe_ep_size
        moe_ep_size = self.moe_ep_size
    else:
        moe_tp_size = self.moe_tp_size
        moe_ep_size = self.moe_ep_size
    assert moe_tp_size * moe_ep_size == self.tensor_parallel_size, (
        f"tensor_parallel_size ({self.tensor_parallel_size}) is not equal to "
        f"moe_tp_size ({self.moe_tp_size}) x moe_ep_size ({self.moe_ep_size})"
        "or moe_tp_size and moe_ep_size should be -1 or one of them should be -1")

    # Class-level mutation, as with context_parallel_size above.
    ParallelConfig.moe_tp_size = moe_tp_size
    ParallelConfig.moe_ep_size = moe_ep_size

    engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(self)
    # Re-apply the MLU block size onto the cache config built by stock vLLM.
    engine_config.cache_config.block_size = BlockSizeInfo.BLOCK_SIZE
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return engine_config
|
||||
|
||||
# NOTE(review): @staticmethod on a module-level function is only directly
# callable on Python >= 3.10; presumably MluHijackObject.apply_hijack handles
# the wrapping when installing it on EngineArgs — confirm.
@staticmethod
def vllm__engine__arg_utils__EngineArgs__add_cli_args(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Hijacked EngineArgs.add_cli_args: stock arguments plus the MLU-only
    context/moe parallel knobs."""
    parser = vllm__engine__arg_utils__EngineArgs__add_cli_args_org(parser)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add --context-parallel-size, --moe-tp-size and --moe-ep-size
    '''
    parser.add_argument('--context-parallel-size',
                        '-cp',
                        type=int,
                        default=1,
                        help='number of context parallel replicas')
    # -1 means "derive from tensor_parallel_size" (see create_engine_config).
    parser.add_argument('--moe-tp-size',
                        type=int,
                        default=-1,
                        help='Number of moe tensor parallel replicas')
    parser.add_argument('--moe-ep-size',
                        type=int,
                        default=-1,
                        help='Number of moe expert parallel replicas')
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return parser
|
||||
|
||||
@classmethod
def vllm__engine__arg_utils__EngineArgs__from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
    """Build (Async)EngineArgs from parsed CLI args, then attach the
    MLU-specific parallelism fields.

    This single replacement is hijacked onto both ``EngineArgs`` and
    ``AsyncEngineArgs``; it dispatches on ``cls`` to the matching
    original constructor.
    """
    # `is` (identity) is the correct comparison for a class object;
    # `==` happened to work but reads as value equality.
    if cls is AsyncEngineArgs:
        engine_args = vllm__engine__arg_utils__AsyncEngineArgs__from_cli_args_org(args)
    else:
        engine_args = vllm__engine__arg_utils__EngineArgs__from_cli_args_org(args)
    # Direct attribute access replaces setattr/getattr with constant
    # literal names (ruff B009/B010) -- identical behavior, clearer intent.
    engine_args.context_parallel_size = args.context_parallel_size
    engine_args.moe_tp_size = args.moe_tp_size
    engine_args.moe_ep_size = args.moe_ep_size
    return engine_args
|
||||
|
||||
|
||||
# Install the hijacked implementations over the stock vLLM entry points.
# Each call replaces the named attribute on the target class with the
# module-level replacement defined above.
MluHijackObject.apply_hijack(EngineArgs,
                             EngineArgs.create_engine_config,
                             vllm__engine__arg_utils__EngineArgs__create_engine_config)
MluHijackObject.apply_hijack(EngineArgs,
                             EngineArgs.add_cli_args,
                             vllm__engine__arg_utils__EngineArgs__add_cli_args)
MluHijackObject.apply_hijack(EngineArgs,
                             EngineArgs.from_cli_args,
                             vllm__engine__arg_utils__EngineArgs__from_cli_args)
# AsyncEngineArgs deliberately reuses the same replacement: the function
# dispatches on `cls` to pick the matching original `from_cli_args`.
MluHijackObject.apply_hijack(AsyncEngineArgs,
                             AsyncEngineArgs.from_cli_args,
                             vllm__engine__arg_utils__EngineArgs__from_cli_args)
|
||||
@@ -0,0 +1 @@
|
||||
from . import llm
|
||||
@@ -0,0 +1,98 @@
|
||||
from typing import Optional, Dict, Any
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
|
||||
TaskOption)
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
vllm__entrypoints__llm__LLM____init__org = LLM.__init__
|
||||
|
||||
def vllm__entrypoints__llm__LLM____init__(
    self,
    model: str,
    tokenizer: Optional[str] = None,
    tokenizer_mode: str = "auto",
    skip_tokenizer_init: bool = False,
    trust_remote_code: bool = False,
    allowed_local_media_path: str = "",
    tensor_parallel_size: int = 1,
    dtype: str = "auto",
    quantization: Optional[str] = None,
    revision: Optional[str] = None,
    tokenizer_revision: Optional[str] = None,
    seed: int = 0,
    gpu_memory_utilization: float = 0.9,
    swap_space: float = 4,
    cpu_offload_gb: float = 0,
    enforce_eager: Optional[bool] = None,
    max_seq_len_to_capture: int = 8192,
    disable_custom_all_reduce: bool = False,
    disable_async_output_proc: bool = False,
    hf_overrides: Optional[HfOverrides] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    # After positional args are removed, move this right below `model`
    task: TaskOption = "auto",
    override_pooler_config: Optional[PoolerConfig] = None,
    **kwargs,
) -> None:
    '''
    Hijacked LLM constructor.

    Pops the MLU-specific `context_parallel_size`, `moe_tp_size` and
    `moe_ep_size` keyword arguments out of **kwargs, stores them on the
    EngineArgs class, then delegates to the original LLM.__init__ with
    every remaining argument unchanged.

    Note: if enforce_eager is unset (enforce_eager is None)
    it defaults to False.
    '''

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add cp and ep parameter
    '''
    # NOTE(review): these assign CLASS attributes on EngineArgs, so the
    # values are process-global and shared by every later EngineArgs use;
    # confirm that constructing two LLMs with different values is not a
    # supported scenario.
    # pop context_parallel_size (default 1: context parallel disabled)
    EngineArgs.context_parallel_size = kwargs.pop("context_parallel_size", 1)
    # pop moe_tp_size (-1 means "derive from tensor_parallel_size")
    EngineArgs.moe_tp_size = kwargs.pop("moe_tp_size", -1)
    # pop moe_ep_size (-1 means "derive from tensor_parallel_size")
    EngineArgs.moe_ep_size = kwargs.pop("moe_ep_size", -1)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    # Forward everything else, byte-for-byte, to the original constructor.
    vllm__entrypoints__llm__LLM____init__org(
        self=self,
        model=model,
        tokenizer=tokenizer,
        tokenizer_mode=tokenizer_mode,
        skip_tokenizer_init=skip_tokenizer_init,
        trust_remote_code=trust_remote_code,
        allowed_local_media_path=allowed_local_media_path,
        tensor_parallel_size=tensor_parallel_size,
        dtype=dtype,
        quantization=quantization,
        revision=revision,
        tokenizer_revision=tokenizer_revision,
        seed=seed,
        gpu_memory_utilization=gpu_memory_utilization,
        swap_space=swap_space,
        cpu_offload_gb=cpu_offload_gb,
        enforce_eager=enforce_eager,
        max_seq_len_to_capture=max_seq_len_to_capture,
        disable_custom_all_reduce=disable_custom_all_reduce,
        disable_async_output_proc=disable_async_output_proc,
        hf_overrides=hf_overrides,
        mm_processor_kwargs=mm_processor_kwargs,
        # After positional args are removed, move this right below `model`
        task=task,
        override_pooler_config=override_pooler_config,
        **kwargs
    )


# Replace LLM.__init__ with the wrapper above.
MluHijackObject.apply_hijack(LLM,
                             LLM.__init__,
                             vllm__entrypoints__llm__LLM____init__)
|
||||
@@ -0,0 +1,7 @@
|
||||
print("Apply Custom VLLM Demo!")
|
||||
from . import distributed
|
||||
from . import engine
|
||||
from . import entrypoints
|
||||
from . import worker
|
||||
from . import config
|
||||
from . import model_executor
|
||||
@@ -0,0 +1,2 @@
|
||||
from . import layers
|
||||
from . import parameter
|
||||
@@ -0,0 +1,2 @@
|
||||
from . import linear
|
||||
from . import feed_forward
|
||||
@@ -0,0 +1,93 @@
|
||||
from typing import Optional, Any
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear
|
||||
)
|
||||
from vllm_mlu.mlu_hijack_utils import set_is_gated, MluHijackObject
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from ....mlu_hijack.distributed.parallel_state import (get_parallel_rank_with_group, get_parallel_world_size_with_group)
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__mlu_hijack__model_executor__layers__feed_forward__FeedForward____init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    up_proj_name: str,
    is_gated: bool,
    down_proj_name: str,
    bias: bool,
    quant_config: Optional[QuantizationConfig] = None,
    skip_bias_add: bool = False,
    reduce_results: bool = True,
    prefix: str = "",
    tp_group: Any = None,
):
    """Hijacked FeedForward constructor that threads an explicit
    ``tp_group`` through to every sub-linear layer (needed for MoE
    tensor/expert parallelism).

    Builds an up-projection (gated -> MergedColumnParallelLinear with two
    fused shards, otherwise ColumnParallelLinear) and a RowParallelLinear
    down-projection, registering both under the caller-supplied names.
    """
    super(FeedForward, self).__init__()
    self.hidden_size = hidden_size
    self.hidden_act = hidden_act
    self.is_gated = is_gated
    self.bias = bias
    self.up_proj_name = up_proj_name
    self.down_proj_name = down_proj_name
    self.quant_config = quant_config
    self.is_initialized = False
    self.skip_bias_add = skip_bias_add
    self.reduce_results = reduce_results
    # Idiom fix: direct boolean expression instead of
    # `True if quant_config is None else False` (identical value).
    self.use_bt_ffn = quant_config is None
    set_is_gated(self.is_gated)
    self.tp_size = get_parallel_world_size_with_group(tp_group)
    self.tp_rank = get_parallel_rank_with_group(tp_group)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add tp_group parameter at the end of each linear class
    '''
    self.tp_group = tp_group
    # up_proj with gate or not
    if self.is_gated:
        # NOTE(review): unlike the non-gated branch, skip_bias_add is not
        # forwarded here -- confirm this asymmetry is intentional.
        up_proj = MergedColumnParallelLinear(hidden_size,
                                             [intermediate_size] * 2,
                                             bias=bias,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.{up_proj_name}",
                                             tp_group=tp_group)
    else:
        up_proj = ColumnParallelLinear(hidden_size,
                                       intermediate_size,
                                       bias=bias,
                                       skip_bias_add=skip_bias_add,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.{up_proj_name}",
                                       tp_group=tp_group)
    self.register_module(up_proj_name, up_proj)

    # down_proj
    down_proj = RowParallelLinear(intermediate_size,
                                  hidden_size,
                                  bias=bias,
                                  skip_bias_add=skip_bias_add,
                                  reduce_results=reduce_results,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.{down_proj_name}",
                                  tp_group=tp_group)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    self.register_module(down_proj_name, down_proj)


MluHijackObject.apply_hijack(FeedForward,
                             FeedForward.__init__,
                             vllm__mlu_hijack__model_executor__layers__feed_forward__FeedForward____init__)
|
||||
@@ -0,0 +1,696 @@
|
||||
from typing import Optional, List, Any, Tuple
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm.distributed import (divide, split_tensor_along_last_dim)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
PerTensorScaleParameter,
|
||||
RowvLLMParameter)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, LinearBase, ColumnParallelLinear,
|
||||
MergedColumnParallelLinear, RowParallelLinear, adjust_marlin_shard,
|
||||
adjust_scalar_to_fused_array)
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from ....mlu_hijack.distributed.parallel_state import (get_parallel_rank_with_group, get_parallel_world_size_with_group,
|
||||
get_tp_group)
|
||||
from ....mlu_hijack.distributed.communication_op import (tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_all_gather)
|
||||
|
||||
vllm__model_executor__layers__linear__LinearBase____init__org = LinearBase.__init__
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__LinearBase____init__(
    self,
    input_size: int,
    output_size: int,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    tp_group: Any = None,
):
    # Hijacked LinearBase.__init__: runs the original constructor
    # unchanged, then records the explicit tp_group plus its world size
    # and rank so subclasses can shard against that group instead of the
    # global tensor-parallel group.
    vllm__model_executor__layers__linear__LinearBase____init__org(self=self,
                                                                  input_size=input_size,
                                                                  output_size=output_size,
                                                                  skip_bias_add=skip_bias_add,
                                                                  params_dtype=params_dtype,
                                                                  quant_config=quant_config,
                                                                  prefix=prefix)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add self.tp_group, world_size and tp_rank to support moe expert parallel
    '''
    # tp_group=None presumably selects the default TP group inside the
    # helpers -- TODO confirm against parallel_state.
    self.tp_group = tp_group
    self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
    self.tp_rank = get_parallel_rank_with_group(self.tp_group)
    '''
    =================
    End of MLU Hijack
    =================
    '''
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__ColumnParallelLinear____init__(
    self,
    input_size: int,
    output_size: int,
    bias: bool = True,
    gather_output: bool = False,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    output_sizes: Optional[List[int]] = None,
    prefix: str = "",
    tp_group: Any = None,
):
    # Hijacked ColumnParallelLinear.__init__: identical to upstream
    # except that partition sizing uses self.tp_world_size (derived from
    # the supplied tp_group in LinearBase) and the output_sizes
    # divisibility check moved here from MergedColumnParallelLinear.
    super(ColumnParallelLinear, self).__init__(input_size, output_size, skip_bias_add, params_dtype,
                                               quant_config, prefix, tp_group)

    self.gather_output = gather_output

    # Divide the weight matrix along the last dimension.
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: modify get_tensor_model_parallel_world_size() to self.tp_world_size
    @brief: move checking output_sizes logic from MergedColumnParallelLinear to here
    '''
    tp_size = self.tp_world_size

    # Every fused shard must split evenly across the TP group.
    if output_sizes is not None:
        assert all(output_size_var % tp_size == 0 for output_size_var in output_sizes)
    '''
    =================
    End of MLU Hijack
    =================
    '''
    assert self.quant_method is not None
    self.output_size_per_partition = divide(self.output_size, tp_size)
    self.output_partition_sizes = [self.output_size_per_partition]
    # If QKV or MergedColumn, use output size of each partition.
    if hasattr(self, "output_sizes"):
        self.output_partition_sizes = [
            divide(output_size, tp_size)
            for output_size in self.output_sizes
        ]

    # NOTE(review): this local rebinding of output_sizes is not read again
    # below -- it mirrors upstream vLLM but appears to be dead here.
    if output_sizes is None:
        output_sizes = [output_size]

    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size,
        output_partition_sizes=self.output_partition_sizes,
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        # v2 loaders are selected by quant-method class name, matching
        # upstream's WEIGHT_LOADER_V2_SUPPORTED registry.
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    if bias:
        # Bias is sharded along the output dimension like the weight.
        self.bias = Parameter(
            torch.empty(self.output_size_per_partition,
                        dtype=params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__ColumnParallelLinear__weight_loader(
        self, param: Parameter, loaded_weight: torch.Tensor):
    # Copy this rank's shard of a full checkpoint tensor into `param`.
    # Narrows `loaded_weight` along output_dim by this rank's offset,
    # except for bitsandbytes-4bit weights, which are already per-rank.
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: modify get_tensor_model_parallel_rank() to self.tp_rank
    '''
    tp_rank = self.tp_rank
    '''
    =================
    End of MLU Hijack
    =================
    '''
    output_dim = getattr(param, "output_dim", None)

    # Special case for GGUF
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type:
        param.weight_type = loaded_weight.item()

    # Materialize GGUF UninitializedParameter
    if is_gguf_weight and isinstance(param, UninitializedParameter):
        param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

    use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

    param_data = param.data
    # bitsandbytes loads the weights of the specific portion
    # no need to narrow here
    if output_dim is not None and not use_bitsandbytes_4bit:
        shard_size = param_data.shape[output_dim]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                             shard_size)

    # Special case for loading scales off disk, which often do not
    # have a shape (such as in the case of AutoFP8).
    if len(loaded_weight.shape) == 0:
        loaded_weight = loaded_weight.reshape(1)

    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__ColumnParallelLinear__forward(
        self, input_, smooth_quant_scale: Optional[torch.Tensor] = None):
    # Hijacked forward: adds an optional smooth-quant input scale and
    # performs the optional all-gather over self.tp_group instead of the
    # global TP group. Returns (output, output_bias) like upstream, where
    # output_bias is only non-None when skip_bias_add is set.
    bias = self.bias if not self.skip_bias_add else None

    # Matrix multiply.
    assert self.quant_method is not None
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Add input_scale parameter.
    '''
    if smooth_quant_scale is not None:
        # presumably consumed by the MLU quant kernel -- see quant_method.apply
        output_parallel = self.quant_method.apply(self, input_, bias,
                                                  input_scale=smooth_quant_scale)
    else:
        output_parallel = self.quant_method.apply(self, input_, bias)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    if self.gather_output:
        # All-gather across the partitions.
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: add tp_group param to tensor_model_parallel_all_gather
        '''
        output = tensor_model_parallel_all_gather(output_parallel, self.tp_group)
        '''
        =================
        End of MLU Hijack
        =================
        '''
    else:
        output = output_parallel
    output_bias = self.bias if self.skip_bias_add else None
    return output, output_bias
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__ColumnParallelLinear__extra_repr(self) -> str:
    """Summarize the layer configuration for nn.Module repr output.

    MLU hijack: tp_size is taken from self.tp_world_size (group-aware)
    rather than the global tensor-parallel world size.
    """
    parts = (
        f"in_features={self.input_size}",
        f"output_features={self.output_size_per_partition}",
        f"bias={self.bias is not None}",
        f"tp_size={self.tp_world_size}",
        f"gather_output={self.gather_output}",
    )
    return ", ".join(parts)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__MergedColumnParallelLinear____init__(
    self,
    input_size: int,
    output_sizes: List[int],
    bias: bool = True,
    gather_output: bool = False,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    tp_group: Any = None,
):
    # Hijacked MergedColumnParallelLinear.__init__: records the per-shard
    # output sizes, then builds one fused ColumnParallelLinear whose total
    # output is sum(output_sizes). The TP divisibility check now lives in
    # the parent constructor (which receives output_sizes explicitly).
    self.output_sizes = output_sizes
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: move checking output_sizes logic from MergedColumnParallelLinear to ColumnParallelLinear.__init__
    '''
    # Upstream check, intentionally disabled here (done by the parent):
    # tp_size = get_tensor_model_parallel_world_size()
    # assert all(output_size % tp_size == 0 for output_size in output_sizes)
    '''
    =================
    End of MLU Hijack
    =================
    '''
    super(MergedColumnParallelLinear, self).__init__(input_size=input_size,
                                                     output_size=sum(output_sizes),
                                                     bias=bias,
                                                     gather_output=gather_output,
                                                     skip_bias_add=skip_bias_add,
                                                     params_dtype=params_dtype,
                                                     quant_config=quant_config,
                                                     output_sizes=self.output_sizes,
                                                     prefix=prefix,
                                                     tp_group=tp_group)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__MergedColumnParallelLinear__weight_loader(self,
                                                                                   param: Parameter,
                                                                                   loaded_weight: torch.Tensor,
                                                                                   loaded_shard_id: Optional[int] = None):
    # Load one shard (or a fused tensor) of a merged column-parallel weight.
    # loaded_shard_id=None means the checkpoint tensor is already fused; it
    # is then split per self.output_sizes and this loader recurses once per
    # shard. All rank/world-size reads use the hijacked group-aware
    # self.tp_rank / self.tp_world_size.
    # Special case for GGUF
    # initialize GGUF param after we know the quantize type
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type:
        param.data[loaded_shard_id].copy_(loaded_weight)
        param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
        return

    if is_gguf_weight:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: modify get_tensor_model_parallel_rank() to self.tp_rank
        @brief: modify get_tensor_model_parallel_world_size() to self.tp_world_size
        '''
        tp_rank = self.tp_rank
        tp_size = self.tp_world_size
        '''
        =================
        End of MLU Hijack
        =================
        '''
        output_dim = getattr(param, "output_dim", None)
        shard_size = loaded_weight.size(output_dim) // tp_size
        start_idx = tp_rank * shard_size

        loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                             shard_size)

        # GGUF shards are collected into a container and materialized once
        # both (gate, up) shards have arrived.
        param.shard_id.append(loaded_shard_id)
        param.shard_id_map[loaded_shard_id] = len(param.data_container)
        param.data_container.append(loaded_weight)
        if len(param.data_container) == 2:
            self.qweight = param.materialize_nested()
        return

    param_data = param.data
    output_dim = getattr(param, "output_dim", None)
    # Special case for AQLM codebooks.
    is_metadata = getattr(param, "is_metadata", False)
    # Special case for per-tensor scale to load scalar into fused array.
    needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

    if loaded_shard_id is None:
        # Loaded weight is already fused on disk (qkv/mlp).
        if output_dim is None:
            if needs_scalar_to_array:
                param_data, loaded_weight = adjust_scalar_to_fused_array(
                    param_data, loaded_weight, 0)

            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return
        # Split the fused tensor into (shard_id, offset, size) slices and
        # re-dispatch each slice through this loader.
        current_shard_offset = 0
        shard_offsets: List[Tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size
        packed_dim = getattr(param, "packed_dim", None)
        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            loaded_weight_shard = loaded_weight.narrow(
                output_dim, shard_offset, shard_size)
            self.weight_loader(param, loaded_weight_shard, shard_id)
        return

    assert loaded_shard_id < len(self.output_sizes)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: modify get_tensor_model_parallel_rank() to self.tp_rank
    @brief: modify get_tensor_model_parallel_world_size() to self.tp_world_size
    '''
    tp_rank = self.tp_rank
    tp_size = self.tp_world_size
    '''
    =================
    End of MLU Hijack
    =================
    '''
    if output_dim is not None:
        # Offset/size of this shard within the fused, TP-sharded param.
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
        shard_size = self.output_sizes[loaded_shard_id] // tp_size
        # Special case for quantization.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        packed_dim = getattr(param, "packed_dim", None)
        if packed_dim == output_dim:
            shard_size = shard_size // param.pack_factor
            shard_offset = shard_offset // param.pack_factor
            # Special case for Marlin.
            shard_size, shard_offset = adjust_marlin_shard(
                param, shard_size, shard_offset)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                        False)
        if use_bitsandbytes_4bit:
            shard_size = loaded_weight.shape[output_dim]
            shard_offset = loaded_weight.shape[output_dim] * \
                loaded_shard_id

        param_data = param_data.narrow(output_dim, shard_offset,
                                       shard_size)
        start_idx = tp_rank * shard_size
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow here
        if not use_bitsandbytes_4bit:
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
    # Special case for AQLM codebooks.
    elif is_metadata:
        # metadata indicates fixed size concatenated along dim 0
        shard_size = loaded_weight.shape[0]
        shard_offset = loaded_shard_id * shard_size
        param_data = param_data.narrow(0, shard_offset, shard_size)

    # Special case for per-tensor scales in fused case.
    elif needs_scalar_to_array:
        param_data, loaded_weight = adjust_scalar_to_fused_array(
            param_data, loaded_weight, loaded_shard_id)

    else:
        ignore_warning = getattr(param, "ignore_warning", False)
        if not ignore_warning:
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "MergedColumnParallelLinear, assume the weight is "
                "the same for all partitions.")

    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__MergedColumnParallelLinear__weight_loader_v2(self,
                                                                                      param: BasevLLMParameter,
                                                                                      loaded_weight: torch.Tensor,
                                                                                      loaded_shard_id: Optional[int] = None):
    # v2 loader: delegates sharding to the vLLM parameter object itself.
    # Only the world-size read differs from upstream (group-aware
    # self.tp_world_size instead of the global TP world size).
    if loaded_shard_id is None:
        if isinstance(param, PerTensorScaleParameter):
            param.load_merged_column_weight(loaded_weight=loaded_weight,
                                            shard_id=0)
            return
        elif type(param) in (RowvLLMParameter, BasevLLMParameter):
            param.load_merged_column_weight(loaded_weight=loaded_weight)
            return
        # TODO: @dsikka - move to parameter.py
        self._load_fused_module_from_checkpoint(param, loaded_weight)
        return

    assert loaded_shard_id < len(self.output_sizes)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: modify get_tensor_model_parallel_world_size() to self.tp_world_size
    '''
    tp_size = self.tp_world_size
    '''
    =================
    End of MLU Hijack
    =================
    '''
    # Per-rank offset/size of this shard within the fused output dim.
    shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
    shard_size = self.output_sizes[loaded_shard_id] // tp_size

    param.load_merged_column_weight(loaded_weight=loaded_weight,
                                    shard_id=loaded_shard_id,
                                    shard_offset=shard_offset,
                                    shard_size=shard_size)
|
||||
|
||||
def vllm__model_executor__layers__linear__RowParallelLinear____init__(
    self,
    input_size: int,
    output_size: int,
    bias: bool = True,
    input_is_parallel: bool = True,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    reduce_results: bool = True,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    tp_group: Any = None,
):
    # Hijacked RowParallelLinear.__init__: identical to upstream except
    # that the input dimension is partitioned by the group-aware
    # self.tp_world_size (set up in the hijacked LinearBase.__init__).
    super(RowParallelLinear, self).__init__(input_size, output_size, skip_bias_add, params_dtype,
                                            quant_config, prefix, tp_group)

    self.input_is_parallel = input_is_parallel
    self.reduce_results = reduce_results

    # Divide the weight matrix along the last dimension.
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: modify get_tensor_model_parallel_rank() to self.tp_rank
    @brief: modify get_tensor_model_parallel_world_size() to self.tp_world_size
    '''
    self.tp_size = self.tp_world_size
    '''
    =================
    End of MLU Hijack
    =================
    '''
    # Row-parallel: each rank owns input_size / tp_size input features.
    self.input_size_per_partition = divide(input_size, self.tp_size)
    assert self.quant_method is not None

    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size_per_partition,
        output_partition_sizes=[self.output_size],
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    # Without the final all-reduce, a fused bias would be added on every
    # rank's partial sum -- reject that combination up front.
    if not reduce_results and (bias and not skip_bias_add):
        raise ValueError("When not reduce the results, adding bias to the "
                         "results can lead to incorrect results")

    if bias:
        # Bias is NOT sharded for row-parallel: full output_size per rank.
        self.bias = Parameter(
            torch.empty(self.output_size, dtype=params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__RowParallelLinear__weight_loader(
        self, param: Parameter, loaded_weight: torch.Tensor):
    # Copy this rank's shard of a full checkpoint tensor into `param`,
    # narrowing along input_dim (row-parallel shards the INPUT dimension).
    # Rank/world-size come from the hijacked group-aware attributes.
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: modify get_tensor_model_parallel_rank() to self.tp_rank
    @brief: modify get_tensor_model_parallel_world_size() to self.tp_world_size
    '''
    tp_rank = self.tp_rank
    tp_size = self.tp_world_size
    '''
    =================
    End of MLU Hijack
    =================
    '''
    input_dim = getattr(param, "input_dim", None)
    use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

    # Special case for GGUF
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type:
        param.weight_type = loaded_weight.item()

    # Materialize GGUF UninitializedParameter
    if is_gguf_weight and isinstance(param, UninitializedParameter):
        weight_shape = list(loaded_weight.shape)
        # NOTE(review): truthiness test means input_dim == 0 is skipped
        # here (upstream behaves the same) -- confirm that is intended.
        if input_dim:
            weight_shape[input_dim] = weight_shape[input_dim] // tp_size
        param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

    param_data = param.data
    # bitsandbytes loads the weights of the specific portion
    # no need to narrow here
    if input_dim is not None and not use_bitsandbytes_4bit:
        shard_size = param_data.shape[input_dim]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                             shard_size)

    # Special case for loading scales off disk, which often do not
    # have a shape (such as in the case of AutoFP8).
    if len(loaded_weight.shape) == 0:
        loaded_weight = loaded_weight.reshape(1)

    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__linear__RowParallelLinear__forward(
    self,
    input_,
    residual: Optional[torch.Tensor] = None
):
    """MLU hijack of ``RowParallelLinear.forward``.

    Differences from upstream vLLM:
    - the TP rank/group come from ``self.tp_rank`` / ``self.tp_group``
      (set per-layer, e.g. for MoE expert parallelism) instead of the
      global tensor-parallel helpers;
    - an optional ``residual`` tensor is forwarded to the quant method
      (assumes the MLU quant_method.apply accepts a ``residual`` kwarg —
      TODO confirm against the hijacked quant methods);
    - when ``self.preload_size`` is set (decode phase only), the TP
      all-reduce is issued asynchronously and overlapped with weight
      preloading via ``mlu_ops.preload``.

    Returns a ``(output, output_bias)`` tuple like the original.
    """
    if self.input_is_parallel:
        input_parallel = input_
    else:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: modify get_tensor_model_parallel_rank() to self.tp_rank
        '''
        tp_rank = self.tp_rank
        '''
        =================
        End of MLU Hijack
        =================
        '''
        splitted_input = split_tensor_along_last_dim(
            input_, num_partitions=self.tp_size)
        input_parallel = splitted_input[tp_rank].contiguous()

    # Matrix multiply.
    assert self.quant_method is not None
    # Only fuse bias add into GEMM for rank 0 (this ensures that
    # bias will not get added more than once in TP>1 case)
    bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
    # Same reasoning for the fused residual: only rank 0 contributes it.
    residual_ = None if self.tp_rank > 0 else residual
    '''
    =====================================================
    Modify by custom vllm_mlu
    =====================================================
    @brief: abandon original reduce if parallel_num is set
    '''
    # Chunked-pipeline mode: when the quant method carries `parallel_num`
    # and this is a prompt (prefill) step, the reduce is handled elsewhere.
    is_parallel_enable = hasattr(self.quant_method, 'parallel_num') and get_is_prompt()
    '''
    =====================================================
    End of custom MLU Hijack
    =====================================================
    '''
    output_parallel = self.quant_method.apply(self,
                                              input_parallel,
                                              bias=bias_,
                                              residual=residual_)
    '''
    =============================
    Modify by custom vllm_mlu
    =============================
    @brief: when preload_size is set, call GroupCoordinator.all_reduce() directly and
            use async_op to set all_reduce paralleled with preload
    '''
    if self.reduce_results and self.tp_size > 1 and not is_parallel_enable:
        if hasattr(self, 'preload_size') and self.preload_size > 0 and not self.is_prompt:
            # Kick off the all-reduce asynchronously, then preload up to
            # `preload_size` MB of weights while it runs; wait before using
            # the reduced result.
            handle = get_tp_group(self.tp_group).all_reduce(output_parallel, async_op=True)
            _MB = 1 << 20
            mlu_ops.preload(self.preloaded_weights[0].data, self.preload_size * _MB)
            preloaded_weights_size = self.preloaded_weights[0].numel() * self.preloaded_weights[0].element_size()
            # If the first weight does not fill the preload budget, spill the
            # remainder into the second preloaded weight (when present).
            if preloaded_weights_size < (self.preload_size * _MB) and len(self.preloaded_weights) > 1:
                mlu_ops.preload(self.preloaded_weights[1].data, (self.preload_size * _MB) - preloaded_weights_size)
            handle.wait()
            output = output_parallel
        else:
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: add tensor_model_parallel_all_reduce() with self.tp_group
            '''
            output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
            '''
            =================
            End of MLU Hijack
            =================
            '''
    else:
        output = output_parallel
    '''
    =========================
    End of custom MLU Hijack
    =========================
    '''
    output_bias = self.bias if self.skip_bias_add else None

    return output, output_bias
|
||||
|
||||
|
||||
# Register the MLU hijacks for vLLM's linear layers: each call replaces the
# original vLLM method with the corresponding re-implementation defined above.
MluHijackObject.apply_hijack(LinearBase,
                             LinearBase.__init__,
                             vllm__model_executor__layers__linear__LinearBase____init__)
MluHijackObject.apply_hijack(ColumnParallelLinear,
                             ColumnParallelLinear.__init__,
                             vllm__model_executor__layers__linear__ColumnParallelLinear____init__)
MluHijackObject.apply_hijack(ColumnParallelLinear,
                             ColumnParallelLinear.weight_loader,
                             vllm__model_executor__layers__linear__ColumnParallelLinear__weight_loader)
MluHijackObject.apply_hijack(ColumnParallelLinear,
                             ColumnParallelLinear.forward,
                             vllm__model_executor__layers__linear__ColumnParallelLinear__forward)
MluHijackObject.apply_hijack(ColumnParallelLinear,
                             ColumnParallelLinear.extra_repr,
                             vllm__model_executor__layers__linear__ColumnParallelLinear__extra_repr)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
                             MergedColumnParallelLinear.__init__,
                             vllm__model_executor__layers__linear__MergedColumnParallelLinear____init__)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
                             MergedColumnParallelLinear.weight_loader,
                             vllm__model_executor__layers__linear__MergedColumnParallelLinear__weight_loader)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
                             MergedColumnParallelLinear.weight_loader_v2,
                             vllm__model_executor__layers__linear__MergedColumnParallelLinear__weight_loader_v2)
MluHijackObject.apply_hijack(RowParallelLinear,
                             RowParallelLinear.__init__,
                             vllm__model_executor__layers__linear__RowParallelLinear____init__)
MluHijackObject.apply_hijack(RowParallelLinear,
                             RowParallelLinear.weight_loader,
                             vllm__model_executor__layers__linear__RowParallelLinear__weight_loader)
MluHijackObject.apply_hijack(RowParallelLinear,
                             RowParallelLinear.forward,
                             vllm__model_executor__layers__linear__RowParallelLinear__forward)
|
||||
@@ -0,0 +1,173 @@
|
||||
from fractions import Fraction
|
||||
from typing import Callable, Optional, Union, Any
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
PerTensorScaleParameter,
|
||||
RowvLLMParameter,
|
||||
_ColumnvLLMParameter)
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from ..distributed.parallel_state import (get_parallel_rank_with_group, get_parallel_world_size_with_group)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__model_executor__parameter__BasevLLMParameter____init__(self, data: torch.Tensor, weight_loader: Callable, tp_group: Any = None):
    """Initialize the BasevLLMParameter (MLU hijack).

    In addition to stashing the weight loader, record the tensor-parallel
    group together with its rank and world size so that weight loading can
    shard against a non-global group (needed for MoE expert parallelism).

    :param data: torch tensor with the parameter data (consumed by
        ``Parameter.__new__``; not used in this initializer)
    :param weight_loader: weight loader callable
    :param tp_group: process group used for sharding; ``None`` selects the
        default tensor-parallel group

    :returns: a torch.nn.parameter
    """
    self._weight_loader = weight_loader
    # MLU hijack additions: cache the group topology on the parameter
    # itself so the load_* hijacks below can read self.tp_rank directly.
    self.tp_group = tp_group
    self.tp_rank = get_parallel_rank_with_group(self.tp_group)
    self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
|
||||
|
||||
|
||||
def vllm__model_executor__parameter___ColumnvLLMParameter__load_column_parallel_weight(self, loaded_weight: torch.Tensor):
    """Copy this rank's output-dim shard of ``loaded_weight`` into ``self.data``.

    MLU hijack: the shard index comes from ``self.tp_rank`` (populated by the
    hijacked ``BasevLLMParameter.__init__``) instead of
    ``get_tensor_model_parallel_rank()``.
    """
    dim = self.output_dim
    shard_len = self.data.shape[dim]
    start = self.tp_rank * shard_len
    shard = loaded_weight.narrow(dim, start, shard_len)
    assert self.data.shape == shard.shape
    self.data.copy_(shard)
|
||||
|
||||
def vllm__model_executor__parameter___ColumnvLLMParameter__load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
    """Load one sub-weight of a merged column-parallel parameter.

    MLU hijack: the source shard is selected with ``self.tp_rank`` instead
    of ``get_tensor_model_parallel_rank()``.

    Keyword args: ``shard_offset`` / ``shard_size`` locate the destination
    slice inside the merged parameter along ``self.output_dim``.
    """
    offset = kwargs.get("shard_offset")
    size = kwargs.get("shard_size")
    # Packed (quantized) parameters store several logical values per
    # physical element; rescale the shard indexes before slicing.
    is_packed_on_output = (
        isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
        and self.packed_dim == self.output_dim)
    if is_packed_on_output:
        size, offset = self.adjust_shard_indexes_for_packing(
            shard_offset=offset, shard_size=size)

    dst = self.data.narrow(self.output_dim, offset, size)
    src = loaded_weight.narrow(self.output_dim, self.tp_rank * size, size)
    assert dst.shape == src.shape
    dst.copy_(src)
|
||||
|
||||
def vllm__model_executor__parameter___ColumnvLLMParameter__load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
    """Load the q/k/v sub-weight of a fused QKV column-parallel parameter.

    MLU hijack: shard selection is driven by ``self.tp_rank`` rather than
    ``get_tensor_model_parallel_rank()``.

    Keyword args: ``shard_offset``/``shard_size`` locate the destination
    slice, ``shard_id`` is one of "q"/"k"/"v", and ``num_heads`` is used to
    map ranks onto replicated k/v shards.
    """
    offset = kwargs.get("shard_offset")
    size = kwargs.get("shard_size")
    shard_id = kwargs.get("shard_id")
    num_heads = kwargs.get("num_heads")

    # Packed (quantized) parameters need their shard indexes rescaled.
    if isinstance(self, (PackedColumnParameter, PackedvLLMParameter)) \
            and self.output_dim == self.packed_dim:
        size, offset = self.adjust_shard_indexes_for_packing(
            shard_offset=offset, shard_size=size)

    # q has one shard per rank; k/v shards are shared by groups of ranks,
    # so several ranks read the same source shard.
    src_index = self.tp_rank if shard_id == "q" else self.tp_rank // num_heads
    dst = self.data.narrow(self.output_dim, offset, size)
    src = loaded_weight.narrow(self.output_dim, src_index * size, size)

    assert dst.shape == src.shape
    dst.copy_(src)
|
||||
|
||||
|
||||
def vllm__model_executor__parameter__RowvLLMParameter__load_row_parallel_weight(self, loaded_weight: torch.Tensor):
    """Copy this rank's input-dim shard of ``loaded_weight`` into ``self.data``.

    MLU hijack: the shard index comes from ``self.tp_rank`` instead of
    ``get_tensor_model_parallel_rank()``.
    """
    dim = self.input_dim
    shard_len = self.data.shape[dim]
    shard = loaded_weight.narrow(dim, self.tp_rank * shard_len, shard_len)

    # Scales loaded off disk may be 0-d tensors (e.g. AutoFP8); promote to 1-d.
    if shard.dim() == 0:
        shard = shard.reshape(1)

    assert self.data.shape == shard.shape
    self.data.copy_(shard)
|
||||
|
||||
|
||||
# Register the MLU hijacks for vLLM's parameter classes: swap each original
# weight-loading method for the tp_rank-aware version defined above.
MluHijackObject.apply_hijack(BasevLLMParameter,
                             BasevLLMParameter.__init__,
                             vllm__model_executor__parameter__BasevLLMParameter____init__)
MluHijackObject.apply_hijack(_ColumnvLLMParameter,
                             _ColumnvLLMParameter.load_column_parallel_weight,
                             vllm__model_executor__parameter___ColumnvLLMParameter__load_column_parallel_weight)
MluHijackObject.apply_hijack(_ColumnvLLMParameter,
                             _ColumnvLLMParameter.load_merged_column_weight,
                             vllm__model_executor__parameter___ColumnvLLMParameter__load_merged_column_weight)
MluHijackObject.apply_hijack(_ColumnvLLMParameter,
                             _ColumnvLLMParameter.load_qkv_weight,
                             vllm__model_executor__parameter___ColumnvLLMParameter__load_qkv_weight)
MluHijackObject.apply_hijack(RowvLLMParameter,
                             RowvLLMParameter.load_row_parallel_weight,
                             vllm__model_executor__parameter__RowvLLMParameter__load_row_parallel_weight)
|
||||
@@ -0,0 +1 @@
|
||||
from . import mlu_worker
|
||||
@@ -0,0 +1,192 @@
|
||||
import gc
|
||||
import os
|
||||
import torch
|
||||
from typing import List, Optional, Set, Tuple, Type
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import init_distributed_environment, set_custom_all_reduce
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.worker.mlu_worker import MLUWorker, _check_if_gpu_supports_dtype
|
||||
from vllm_mlu.worker.mlu_worker import MLUWorker_V2
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from ..distributed.parallel_state import ensure_model_parallel_initialized
|
||||
|
||||
import functools
|
||||
from collections import defaultdict
|
||||
from vllm.logger import init_logger
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
|
||||
from ..distributed.parallel_state import (get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size,
|
||||
get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size)
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__worker__mlu_worker__init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment for an MLU worker.

    Sets up the CNCL-backed process group and the model-parallel groups.

    MLU hijack: ``ensure_model_parallel_initialized`` receives the whole
    ``parallel_config`` (instead of individual sizes) so the group layout
    can also cover context parallel and the MoE tensor/expert-parallel
    sizes.
    """
    # Honor the user's choice about the custom all-reduce kernel.
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank,
                                 backend='cncl')

    ensure_model_parallel_initialized(parallel_config=parallel_config)
|
||||
|
||||
def vllm__worker__mlu_worker__MLUWorker__init_device(self) -> None:
    """Bind this worker to its MLU device and initialize distributed state.

    MLU hijack: delegates distributed setup to
    ``vllm__worker__mlu_worker__init_worker_distributed_environment`` so the
    extended (context/MoE) parallel groups are created.

    Raises:
        RuntimeError: if the configured device is not an MLU.
    """
    # Guard clause: anything other than an MLU device is unsupported.
    if self.device_config.device.type != "mlu":
        raise RuntimeError(
            f"Not support device type: {self.device_config.device}")

    # torch.distributed.all_reduce does not free its input tensor until the
    # next synchronization point, so memory grows with the number of
    # all_reduce calls; this env var disables that behavior. Related issue:
    # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
    os.environ["TORCH_CNCL_AVOID_RECORD_STREAMS"] = "1"

    # Ray sets this env var, and it causes exceptions during graph building.
    os.environ.pop("CNCL_ASYNC_ERROR_HANDLING", None)

    self.device = torch.device(f"mlu:{self.local_rank}")
    torch.mlu.set_device(self.device)

    _check_if_gpu_supports_dtype(self.model_config.dtype)
    gc.collect()
    torch.mlu.empty_cache()
    # Record free device memory before any model allocation happens.
    self.init_gpu_memory = torch.mlu.mem_get_info()[0]

    # Initialize the distributed environment (hijacked version).
    vllm__worker__mlu_worker__init_worker_distributed_environment(
        self.parallel_config, self.rank,
        self.distributed_init_method, self.local_rank)

    # Set random seed.
    set_random_seed(self.model_config.seed)
|
||||
|
||||
|
||||
def default_act_range_value():
    """Return a fresh default entry for the per-layer activation-range table.

    Used as the ``defaultdict`` factory in the smooth-quant calibration
    hooks; each call must return a new dict (with a new ``input_id`` list)
    so entries are never shared between layers.
    """
    keys = ("x", "split", "is_linear", "is_qkv", "q_proj_size",
            "num_kv_head_replicas", "is_merge", "input_id", "self_rank",
            "rank", "tensor_rank", "tp_world_size", "moe_tp_rank",
            "moe_tp_world_size", "moe_ep_rank", "moe_ep_world_size",
            "weight")
    # Note: the tuple below is rebuilt on every call, so the [] default is a
    # fresh list each time (never shared across entries).
    defaults = (None, None, False, False, 0,
                1, False, [], 0,
                None, None, None, None,
                None, None, None,
                None)
    return dict(zip(keys, defaults))
|
||||
|
||||
def vllm_mlu__worker__mlu_worker__MLUWorker_V2__setup_smooth_hook(self,
                                                                  is_save_input_id: bool = False,
                                                                  is_save_moe_info: bool = False):
    """Install forward hooks for smooth-quant activation-range calibration.

    Walks the model's modules, disables fused FFN/MoE paths (so the plain
    linear layers are actually executed and can be observed), records
    per-layer metadata in ``self.act_range``, and registers
    ``self.stat_input_hook`` as a forward hook on every parallel
    linear/embedding layer. Hook handles are kept in ``self.hooks`` so a
    caller can remove them later.

    :param is_save_input_id: forwarded to ``stat_input_hook``; presumably
        makes the hook also record input ids — verify against its definition.
    :param is_save_moe_info: when True, additionally record the global rank
        plus TP / MoE-TP / MoE-EP rank and world-size info per layer, and
        keep a reference to each expert layer's weight tensor.
    """
    model = self.model_runner.model
    # defaultdict: a first access to act_range[name] creates a fresh entry.
    self.act_range = defaultdict(default_act_range_value)
    self.hooks = []
    linear_class_list = (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
    other_class_list = (VocabParallelEmbedding, ParallelLMHead)
    class_list = linear_class_list + other_class_list
    # NOTE(review): this is a single class, not a tuple (no trailing comma) —
    # isinstance() accepts a bare class, so behavior is correct, but the
    # name is misleading.
    row_class_list = (RowParallelLinear)

    for name, m in model.named_modules():
        # Disable fused execution paths so hooks see the unfused layers.
        if isinstance(m, FeedForward):
            m.use_bt_ffn = False
        if isinstance(m, SparseMoeMlp):
            m.is_use_fused_moe = False

        if isinstance(m, class_list):
            is_linear = True if isinstance(m, linear_class_list) else False
            # Row-parallel layers shard along the input dim ("row"); all
            # others observed here shard along the output dim ("col").
            split_type = "row" if isinstance(m, row_class_list) else "col"
            self.act_range[name]["split"] = split_type
            self.act_range[name]["is_linear"] = is_linear
            if isinstance(m, QKVParallelLinear):
                self.act_range[name]["is_qkv"] = True
                # Size of the q projection inside the fused qkv weight.
                self.act_range[name]["q_proj_size"] = m.num_heads * m.head_size
                self.act_range[name]["num_kv_head_replicas"] = m.num_kv_head_replicas
            self.act_range[name]["is_merge"] = isinstance(m, MergedColumnParallelLinear)
            if is_save_moe_info:
                self.act_range[name]["rank"] = torch.distributed.get_rank()
                self.act_range[name]["tensor_rank"] = get_tensor_model_parallel_rank()
                self.act_range[name]["tp_world_size"] = get_tensor_model_parallel_world_size()
                self.act_range[name]["moe_tp_rank"] = get_moe_tensor_parallel_rank()
                self.act_range[name]["moe_tp_world_size"] = get_moe_tensor_parallel_world_size()
                self.act_range[name]["moe_ep_rank"] = get_moe_expert_parallel_rank()
                self.act_range[name]["moe_ep_world_size"] = get_moe_expert_parallel_world_size()
                if ".expert." in name:
                    # Keep a live reference to the expert weight tensor.
                    self.act_range[name]["weight"] = m.weight
            logger.info(f"rank:{self.rank}, add hook to {name}, is_linear:{is_linear}, split_type:{split_type}")
            self.hooks.append(
                m.register_forward_hook(
                    functools.partial(self.stat_input_hook,
                                      name=name,
                                      act_range=self.act_range,
                                      is_linear=is_linear,
                                      is_save_input_id=is_save_input_id)))
|
||||
|
||||
|
||||
def vllm_mlu__worker__mlu_worker__MLUWorker_V2__get_act_range(self):
    """Return a copy of ``self.act_range`` with every tensor moved to CPU.

    Tensors are copied with ``.to("cpu")`` (including those nested inside
    the ``input_id`` lists); all non-tensor values are carried over as-is.
    The original device-resident table is left untouched.
    """
    cpu_ranges = defaultdict(default_act_range_value)
    for layer_name, layer_range in self.act_range.items():
        for key, value in layer_range.items():
            if isinstance(value, torch.Tensor):
                cpu_ranges[layer_name][key] = value.to("cpu")
            elif key == "input_id" and isinstance(value, list):
                # Move any tensors in the input_id list to CPU, keep the rest.
                moved = [item.to("cpu") if isinstance(item, torch.Tensor) else item
                         for item in value]
                cpu_ranges[layer_name][key].extend(moved)
            else:
                cpu_ranges[layer_name][key] = value

    return cpu_ranges
|
||||
|
||||
|
||||
# Register the worker-side hijacks. init_device replaces the existing method;
# setup_smooth_hook / get_act_range are passed by name as strings —
# presumably because MLUWorker has no such attributes yet and the hijack
# grafts them on (verify against MluHijackObject.apply_hijack).
MluHijackObject.apply_hijack(MLUWorker,
                             MLUWorker.init_device,
                             vllm__worker__mlu_worker__MLUWorker__init_device)
MluHijackObject.apply_hijack(MLUWorker,
                             "setup_smooth_hook",
                             vllm_mlu__worker__mlu_worker__MLUWorker_V2__setup_smooth_hook)
MluHijackObject.apply_hijack(MLUWorker,
                             "get_act_range",
                             vllm_mlu__worker__mlu_worker__MLUWorker_V2__get_act_range)
|
||||
Reference in New Issue
Block a user