[Ops] Fix bug in register_custom_ops without forward_context (#2883)
### What this PR does / why we need it?
This PR fixed the bug in register_custom_ops without forward_context. We
set try-except to consider this situation.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: main
- vLLM main:
7920de0a2a
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -139,11 +139,13 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
|
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
|
||||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
|
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
|
||||||
# buffer size for gate up prefetch
|
# buffer size for gate up prefetch
|
||||||
"MLP_GATE_UP_PREFETCH_SIZE":
|
"VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE":
|
||||||
lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
|
lambda: int(
|
||||||
|
os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
|
||||||
# buffer size for down proj prefetch
|
# buffer size for down proj prefetch
|
||||||
"MLP_DOWN_PREFETCH_SIZE":
|
"VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE":
|
||||||
lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
|
lambda: int(
|
||||||
|
os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
|
||||||
# Whether to enable dense model and general optimizations for better performance.
|
# Whether to enable dense model and general optimizations for better performance.
|
||||||
# Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
|
# Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
|
||||||
# However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
|
# However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
|||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
tensor_model_parallel_reduce_scatter)
|
tensor_model_parallel_reduce_scatter)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
from vllm.logger import logger
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
@@ -14,12 +15,18 @@ import vllm_ascend.envs as envs_ascend
|
|||||||
|
|
||||||
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
||||||
residual: torch.Tensor) -> torch.Tensor:
|
residual: torch.Tensor) -> torch.Tensor:
|
||||||
|
try:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
except AssertionError:
|
||||||
|
logger.info("Forward context is None, skipping the operation.")
|
||||||
|
return residual
|
||||||
|
|
||||||
if x.size(0) != residual.size(0):
|
if x.size(0) != residual.size(0):
|
||||||
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
|
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
|
||||||
assert flashcomm_v1_enabled is True, (
|
assert flashcomm_v1_enabled is True, (
|
||||||
"Currently, this situation only occurs "
|
"Currently, this situation only occurs "
|
||||||
"when flashcomm_v1 is enabled")
|
"when flashcomm_v1 is enabled")
|
||||||
pad_size = get_forward_context().pad_size
|
pad_size = forward_context.pad_size
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
residual = F.pad(residual, (0, 0, 0, pad_size))
|
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
@@ -31,19 +38,31 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
|
|||||||
|
|
||||||
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
|
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
|
||||||
label: bool) -> torch.Tensor:
|
label: bool) -> torch.Tensor:
|
||||||
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
|
try:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
except AssertionError:
|
||||||
|
logger.info("Forward context is None, skipping the operation.")
|
||||||
|
return x
|
||||||
|
|
||||||
|
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
|
||||||
if flashcomm_v1_enabled and label:
|
if flashcomm_v1_enabled and label:
|
||||||
x = tensor_model_parallel_all_gather(x, 0)
|
x = tensor_model_parallel_all_gather(x, 0)
|
||||||
pad_size = get_forward_context().pad_size
|
pad_size = forward_context.pad_size
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
x = x[:-pad_size, :]
|
x = x[:-pad_size, :]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
|
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
|
||||||
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
|
try:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
except AssertionError:
|
||||||
|
logger.info("Forward context is None, skipping the operation.")
|
||||||
|
return tensor_model_parallel_all_reduce(x)
|
||||||
|
|
||||||
|
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
|
||||||
if flashcomm_v1_enabled:
|
if flashcomm_v1_enabled:
|
||||||
pad_size = get_forward_context().pad_size
|
pad_size = forward_context.pad_size
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
x = F.pad(x, (0, 0, 0, pad_size))
|
x = F.pad(x, (0, 0, 0, pad_size))
|
||||||
return tensor_model_parallel_reduce_scatter(x, 0)
|
return tensor_model_parallel_reduce_scatter(x, 0)
|
||||||
@@ -53,7 +72,12 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
|
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
|
||||||
prefix: str) -> None:
|
prefix: str) -> None:
|
||||||
|
try:
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
|
except AssertionError:
|
||||||
|
logger.info("Forward context is None, skipping the operation.")
|
||||||
|
return
|
||||||
|
|
||||||
if not forward_context.prefetch_mlp_enabled:
|
if not forward_context.prefetch_mlp_enabled:
|
||||||
return
|
return
|
||||||
model_instance = forward_context.model_instance
|
model_instance = forward_context.model_instance
|
||||||
@@ -67,9 +91,9 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
|
|||||||
prefetch_stream.wait_stream(torch.npu.current_stream())
|
prefetch_stream.wait_stream(torch.npu.current_stream())
|
||||||
|
|
||||||
with torch.npu.stream(prefetch_stream):
|
with torch.npu.stream(prefetch_stream):
|
||||||
MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
|
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
|
||||||
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
|
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
|
||||||
x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
|
x_dependency, mlp_gate_up_prefetch_size)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
@@ -79,7 +103,12 @@ def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
|
|||||||
|
|
||||||
|
|
||||||
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
|
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
|
||||||
|
try:
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
|
except AssertionError:
|
||||||
|
logger.info("Forward context is None, skipping the operation.")
|
||||||
|
return
|
||||||
|
|
||||||
if not forward_context.prefetch_mlp_enabled:
|
if not forward_context.prefetch_mlp_enabled:
|
||||||
return
|
return
|
||||||
forward_context.prefetch_mlp_down_proj = True
|
forward_context.prefetch_mlp_down_proj = True
|
||||||
@@ -91,9 +120,9 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
|
|||||||
prefetch_stream.wait_stream(torch.npu.current_stream())
|
prefetch_stream.wait_stream(torch.npu.current_stream())
|
||||||
|
|
||||||
with torch.npu.stream(prefetch_stream):
|
with torch.npu.stream(prefetch_stream):
|
||||||
MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
|
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
|
||||||
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
|
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
|
||||||
x_dependency, MLP_DOWN_PREFETCH_SIZE)
|
x_dependency, mlp_down_prefetch_size)
|
||||||
forward_context.layer_idx += 1
|
forward_context.layer_idx += 1
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -104,12 +133,17 @@ def _maybe_prefetch_mlp_down_proj_impl_fake(
|
|||||||
|
|
||||||
|
|
||||||
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
|
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
|
||||||
|
try:
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
|
except AssertionError:
|
||||||
|
logger.info("Forward context is None, skipping the operation.")
|
||||||
|
return
|
||||||
|
|
||||||
if not forward_context.prefetch_mlp_enabled:
|
if not forward_context.prefetch_mlp_enabled:
|
||||||
return
|
return
|
||||||
if forward_context.prefetch_mlp_gate_up_proj or \
|
if forward_context.prefetch_mlp_gate_up_proj or \
|
||||||
forward_context.prefetch_mlp_down_proj:
|
forward_context.prefetch_mlp_down_proj:
|
||||||
prefetch_stream = get_forward_context().prefetch_stream
|
prefetch_stream = forward_context.prefetch_stream
|
||||||
# wait until prefetch done
|
# wait until prefetch done
|
||||||
torch.npu.current_stream().wait_stream(prefetch_stream)
|
torch.npu.current_stream().wait_stream(prefetch_stream)
|
||||||
forward_context.prefetch_mlp_gate_up_proj = False
|
forward_context.prefetch_mlp_gate_up_proj = False
|
||||||
|
|||||||
Reference in New Issue
Block a user