[Bugfix] rename enable_flash_comm_v1 back to enable_sp (#6883)
### What this PR does / why we need it?
PR #5632 introduced a bug by replacing some branches gated by enable_sp
with enable_flash_comm_v1. As a result, when enable_shared_expert_dp is
enabled alone (i.e., VLLM_ASCEND_ENABLE_FLASHCOMM1=0 and
VLLM_ASCEND_ENABLE_FLASHCOMM=0), the behavior becomes inconsistent with
the previous logic and leads to accuracy issues. This PR restores the
original enable_sp-based branching to recover expected behavior and
accuracy.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
#### 1. start server
``` bash
vllm serve /home/weights/DeepSeek-V2-Lite-W8A8/ \
--port 8001 \
--served-model-name auto \
--max-model-len 1024 \
--enforce-eager \
--tensor-parallel-size 2 \
--data-parallel-size 2 \
--gpu-memory-utilization 0.9 \
--enable-expert-parallel \
--additional-config '{"enable_shared_expert_dp": true}'
```
#### 2. curl
```bash
curl -s http://localhost:8001/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "auto",
"messages": [
{"role": "user", "content": "Hello. I have a question. Who are you?"}
],
"max_tokens": 10,
"temperature": 0.0,
"ignore_eos_token": true
}'
```
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -12,7 +12,7 @@ import vllm_ascend.envs as envs_ascend
|
|||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.utils import (
|
from vllm_ascend.utils import (
|
||||||
AscendDeviceType,
|
AscendDeviceType,
|
||||||
enable_flash_comm_v1,
|
enable_sp,
|
||||||
flashcomm2_enable,
|
flashcomm2_enable,
|
||||||
get_ascend_device_type,
|
get_ascend_device_type,
|
||||||
has_layer_idx,
|
has_layer_idx,
|
||||||
@@ -92,14 +92,14 @@ def set_ascend_forward_context(
|
|||||||
# main model and drafter model may have different architecture
|
# main model and drafter model may have different architecture
|
||||||
is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
|
is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
|
||||||
if is_context_moe_model:
|
if is_context_moe_model:
|
||||||
flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None
|
flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
|
||||||
mmrs_fusion = False
|
mmrs_fusion = False
|
||||||
elif is_draft_model:
|
elif is_draft_model:
|
||||||
# TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
|
# TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
|
||||||
# Disable it to avoid more problems.
|
# Disable it to avoid more problems.
|
||||||
flash_comm_v1_enabled = False
|
flash_comm_v1_enabled = False
|
||||||
else:
|
else:
|
||||||
flash_comm_v1_enabled = enable_flash_comm_v1() and num_tokens is not None and num_tokens > 1000
|
flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
|
||||||
forward_context.mmrs_fusion = mmrs_fusion
|
forward_context.mmrs_fusion = mmrs_fusion
|
||||||
forward_context.num_tokens = num_tokens
|
forward_context.num_tokens = num_tokens
|
||||||
forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
|
forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
|
|||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
|
||||||
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
|
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
|
||||||
from vllm_ascend.utils import enable_flash_comm_v1, maybe_trans_nz
|
from vllm_ascend.utils import enable_sp, maybe_trans_nz
|
||||||
|
|
||||||
|
|
||||||
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
|
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
|
||||||
@@ -240,7 +240,7 @@ class AscendRowParallelLinear(RowParallelLinear):
|
|||||||
disable_tp: bool = False,
|
disable_tp: bool = False,
|
||||||
):
|
):
|
||||||
# TODO(kunpengW-code): Specifying the prefix in linear layers of some models in the vLLM.
|
# TODO(kunpengW-code): Specifying the prefix in linear layers of some models in the vLLM.
|
||||||
if enable_flash_comm_v1():
|
if enable_sp():
|
||||||
compilation_config = get_current_vllm_config().compilation_config
|
compilation_config = get_current_vllm_config().compilation_config
|
||||||
unique_prefix = prefix
|
unique_prefix = prefix
|
||||||
if prefix in compilation_config.static_forward_context:
|
if prefix in compilation_config.static_forward_context:
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
|
|||||||
from vllm_ascend.utils import (
|
from vllm_ascend.utils import (
|
||||||
enable_dsa_cp,
|
enable_dsa_cp,
|
||||||
enable_dsa_cp_with_layer_shard,
|
enable_dsa_cp_with_layer_shard,
|
||||||
enable_flash_comm_v1,
|
enable_sp,
|
||||||
flashcomm2_enable,
|
flashcomm2_enable,
|
||||||
get_flashcomm2_reorgnized_batch_ids,
|
get_flashcomm2_reorgnized_batch_ids,
|
||||||
get_weight_prefetch_method,
|
get_weight_prefetch_method,
|
||||||
@@ -466,7 +466,7 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
|
|||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
if enable_flash_comm_v1():
|
if enable_sp():
|
||||||
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
|
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
|
||||||
|
|
||||||
# Trigger async broadcast before matmul to overlap communication.
|
# Trigger async broadcast before matmul to overlap communication.
|
||||||
@@ -649,7 +649,7 @@ def _get_column_parallel_op(
|
|||||||
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
|
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
|
||||||
if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
|
if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
|
||||||
return Flashcomm2OshardQKVParallelOp(layer)
|
return Flashcomm2OshardQKVParallelOp(layer)
|
||||||
if enable_flash_comm_v1():
|
if enable_sp():
|
||||||
if "shared_expert" in prefix:
|
if "shared_expert" in prefix:
|
||||||
return None
|
return None
|
||||||
sp_column_prefix = [
|
sp_column_prefix = [
|
||||||
@@ -688,7 +688,7 @@ def _get_row_parallel_op(
|
|||||||
if flashcomm2_enable():
|
if flashcomm2_enable():
|
||||||
if "o_proj" in prefix or "out_proj" in prefix:
|
if "o_proj" in prefix or "out_proj" in prefix:
|
||||||
return Flashcomm2OProjRowParallelOp(layer)
|
return Flashcomm2OProjRowParallelOp(layer)
|
||||||
if enable_flash_comm_v1():
|
if enable_sp():
|
||||||
if "shared_expert" in prefix:
|
if "shared_expert" in prefix:
|
||||||
return None
|
return None
|
||||||
sp_row_prefixes = [
|
sp_row_prefixes = [
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ from vllm_ascend.utils import (
|
|||||||
COMPRESSED_TENSORS_METHOD,
|
COMPRESSED_TENSORS_METHOD,
|
||||||
AscendDeviceType,
|
AscendDeviceType,
|
||||||
check_kv_extra_config,
|
check_kv_extra_config,
|
||||||
enable_sp,
|
|
||||||
flashcomm2_enable,
|
flashcomm2_enable,
|
||||||
get_ascend_device_type,
|
get_ascend_device_type,
|
||||||
is_moe_model,
|
is_moe_model,
|
||||||
@@ -48,7 +47,7 @@ from vllm_ascend.utils import (
|
|||||||
update_aclgraph_sizes,
|
update_aclgraph_sizes,
|
||||||
update_cudagraph_capture_sizes,
|
update_cudagraph_capture_sizes,
|
||||||
is_310p,
|
is_310p,
|
||||||
enable_flash_comm_v1,
|
enable_sp,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -402,7 +401,7 @@ class NPUPlatform(Platform):
|
|||||||
)
|
)
|
||||||
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
|
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
|
||||||
|
|
||||||
if enable_flash_comm_v1():
|
if enable_sp(vllm_config):
|
||||||
assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
|
assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
|
||||||
Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
|
Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
|
||||||
For optimal performance with VL models, we recommend enabling Sequence Parallelism \
|
For optimal performance with VL models, we recommend enabling Sequence Parallelism \
|
||||||
|
|||||||
@@ -719,15 +719,6 @@ def matmul_allreduce_enable() -> bool:
|
|||||||
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
||||||
|
|
||||||
|
|
||||||
def enable_flash_comm_v1():
|
|
||||||
return (
|
|
||||||
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
|
|
||||||
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
|
|
||||||
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
|
|
||||||
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def enable_sp_by_pass(vllm_config: VllmConfig):
|
def enable_sp_by_pass(vllm_config: VllmConfig):
|
||||||
return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp
|
return not vllm_config.model_config.enforce_eager and vllm_config.compilation_config.pass_config.enable_sp
|
||||||
|
|
||||||
@@ -739,7 +730,12 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
|
|||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
_ENABLE_SP = enable_sp_by_pass(vllm_config) or enable_flash_comm_v1()
|
_ENABLE_SP = (
|
||||||
|
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||||
|
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||||
|
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
|
||||||
|
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
|
||||||
|
)
|
||||||
|
|
||||||
if not _ENABLE_SP and enable_shared_expert_dp:
|
if not _ENABLE_SP and enable_shared_expert_dp:
|
||||||
_ENABLE_SP = True
|
_ENABLE_SP = True
|
||||||
@@ -1104,7 +1100,7 @@ def enable_dsa_cp() -> bool:
|
|||||||
is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
|
is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
|
||||||
vllm_config.model_config.hf_text_config, "index_topk"
|
vllm_config.model_config.hf_text_config, "index_topk"
|
||||||
)
|
)
|
||||||
return bool(is_ds_v32 and enable_flash_comm_v1())
|
return bool(is_ds_v32 and enable_sp())
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
|
|||||||
@@ -113,8 +113,8 @@ from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
|
|||||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||||
from vllm_ascend.utils import (
|
from vllm_ascend.utils import (
|
||||||
check_gdn_layer,
|
check_gdn_layer,
|
||||||
enable_flash_comm_v1,
|
|
||||||
enable_sp,
|
enable_sp,
|
||||||
|
enable_sp_by_pass,
|
||||||
is_drafter_moe_model,
|
is_drafter_moe_model,
|
||||||
is_moe_model,
|
is_moe_model,
|
||||||
lmhead_tp_enable,
|
lmhead_tp_enable,
|
||||||
@@ -1745,7 +1745,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
# Pad tokens to multiple of tensor_parallel_size when
|
# Pad tokens to multiple of tensor_parallel_size when
|
||||||
# enabled collective fusion for SP
|
# enabled collective fusion for SP
|
||||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||||
if enable_sp(self.vllm_config):
|
if enable_sp(self.vllm_config) or enable_sp_by_pass(self.vllm_config):
|
||||||
return round_up(num_scheduled_tokens, tp_size)
|
return round_up(num_scheduled_tokens, tp_size)
|
||||||
return num_scheduled_tokens
|
return num_scheduled_tokens
|
||||||
|
|
||||||
@@ -2300,7 +2300,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
# tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading
|
# tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading
|
||||||
# to incorrect memory estimation and potentially causing OOM.
|
# to incorrect memory estimation and potentially causing OOM.
|
||||||
intermediate_tokens = num_tokens_padded
|
intermediate_tokens = num_tokens_padded
|
||||||
if enable_flash_comm_v1():
|
if enable_sp():
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size
|
intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size
|
||||||
if self.intermediate_tensors is None:
|
if self.intermediate_tensors is None:
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
|||||||
from vllm_ascend.utils import (
|
from vllm_ascend.utils import (
|
||||||
AscendDeviceType,
|
AscendDeviceType,
|
||||||
check_ascend_device_type,
|
check_ascend_device_type,
|
||||||
enable_flash_comm_v1,
|
enable_sp,
|
||||||
get_ascend_device_type,
|
get_ascend_device_type,
|
||||||
register_ascend_customop,
|
register_ascend_customop,
|
||||||
)
|
)
|
||||||
@@ -376,7 +376,7 @@ class NPUWorker(WorkerBase):
|
|||||||
if forward_pass and not get_pp_group().is_first_rank:
|
if forward_pass and not get_pp_group().is_first_rank:
|
||||||
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
|
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
|
||||||
# it will conflict with the all-gather operation in flashcomm1.
|
# it will conflict with the all-gather operation in flashcomm1.
|
||||||
if enable_flash_comm_v1():
|
if enable_sp():
|
||||||
all_gather_group = None
|
all_gather_group = None
|
||||||
else:
|
else:
|
||||||
all_gather_group = get_tp_group()
|
all_gather_group = get_tp_group()
|
||||||
@@ -393,7 +393,7 @@ class NPUWorker(WorkerBase):
|
|||||||
assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
|
assert parallel_config.distributed_executor_backend != ("external_launcher") and not get_pp_group().is_last_rank
|
||||||
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
|
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
|
||||||
# it will conflict with the all-gather operation in flashcomm1.
|
# it will conflict with the all-gather operation in flashcomm1.
|
||||||
if enable_flash_comm_v1():
|
if enable_sp():
|
||||||
all_gather_group = None
|
all_gather_group = None
|
||||||
else:
|
else:
|
||||||
all_gather_group = get_tp_group()
|
all_gather_group = get_tp_group()
|
||||||
|
|||||||
Reference in New Issue
Block a user