[Refactor] Import global var from vllm instead of overwriting it (#5469)
### What this PR does / why we need it?
Import global variables from vllm instead of overwriting them, so that we
use the correct global variable values
- vLLM version: v0.13.0
- vLLM main:
5326c89803
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -1,40 +0,0 @@
|
|||||||
import torch
|
|
||||||
from pytest_mock import MockerFixture
|
|
||||||
from vllm.config import SchedulerConfig, VllmConfig
|
|
||||||
|
|
||||||
from tests.ut.base import PytestBase
|
|
||||||
from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor
|
|
||||||
|
|
||||||
|
|
||||||
class TestMinPLogitsProcessorInitFunc(PytestBase):
|
|
||||||
|
|
||||||
def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture):
|
|
||||||
device_cpu = torch.device("cpu")
|
|
||||||
device_npu = torch.device("npu")
|
|
||||||
is_pin_memory = False
|
|
||||||
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
|
|
||||||
mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
|
|
||||||
mock_scheduler_config.decode_max_num_seqs = 0
|
|
||||||
mock_scheduler_config.max_num_seqs = 128
|
|
||||||
mock_vllm_config.scheduler_config = mock_scheduler_config
|
|
||||||
# torch.zeros/torch.empty returns error on online ut machine, so mock it
|
|
||||||
mock_tensor = torch.zeros((256, ),
|
|
||||||
dtype=torch.float32,
|
|
||||||
pin_memory=False)
|
|
||||||
mocker.patch("torch.zeros", return_value=mock_tensor)
|
|
||||||
mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
|
|
||||||
mocker.patch("torch.empty", return_value=mock_empty_tensor)
|
|
||||||
|
|
||||||
processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu,
|
|
||||||
is_pin_memory)
|
|
||||||
|
|
||||||
assert processor_cpu.min_p is not None
|
|
||||||
assert processor_cpu.use_double_tensor is False
|
|
||||||
assert processor_cpu.min_p_cpu.shape[0] == 256
|
|
||||||
|
|
||||||
processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu,
|
|
||||||
is_pin_memory)
|
|
||||||
|
|
||||||
assert processor_cpu.min_p is not None
|
|
||||||
assert processor_cpu.use_double_tensor is True
|
|
||||||
assert processor_cpu.min_p_cpu.shape[0] == 256
|
|
||||||
@@ -32,6 +32,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.distributed.mooncake_connector import GET_META_MSG
|
||||||
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
||||||
from vllm_ascend.distributed.utils import (align_memory,
|
from vllm_ascend.distributed.utils import (align_memory,
|
||||||
get_transfer_timeout_value,
|
get_transfer_timeout_value,
|
||||||
@@ -44,7 +45,6 @@ if TYPE_CHECKING:
|
|||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
GET_META_MSG = b"get_meta_msg"
|
|
||||||
DONE_SENDING_MSG = b"done_sending_msg"
|
DONE_SENDING_MSG = b"done_sending_msg"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
PAD_SLOT_ID = -1
|
|
||||||
|
|
||||||
|
|
||||||
def causal_conv1d_ref(
|
def causal_conv1d_ref(
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from vllm.model_executor.models import bert
|
|||||||
|
|
||||||
# aclgraph does not support shift operator for now
|
# aclgraph does not support shift operator for now
|
||||||
# TODO: revert me when aclgraph supports shift operator
|
# TODO: revert me when aclgraph supports shift operator
|
||||||
TOKEN_TYPE_SHIFT = 30
|
|
||||||
TOKEN_TYPE_MULTIPLIER = 1 << 30
|
TOKEN_TYPE_MULTIPLIER = 1 << 30
|
||||||
TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1
|
TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1
|
||||||
|
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
import itertools
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import TYPE_CHECKING, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.v1.sample import logits_processor
|
|
||||||
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
|
|
||||||
MinTokensLogitsProcessor)
|
|
||||||
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
|
||||||
from vllm.v1.sample.logits_processor.state import LogitsProcessors
|
|
||||||
|
|
||||||
from vllm_ascend.sample.logits_processor.builtin import \
|
|
||||||
AscendMinPLogitsProcessor
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
# Error message when the user tries to initialize vLLM with a pooling model
|
|
||||||
# and custom logitsproces
|
|
||||||
STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
|
|
||||||
" logits processors.")
|
|
||||||
|
|
||||||
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
|
|
||||||
MinTokensLogitsProcessor,
|
|
||||||
LogitBiasLogitsProcessor,
|
|
||||||
AscendMinPLogitsProcessor,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_logitsprocs(
|
|
||||||
vllm_config: "VllmConfig",
|
|
||||||
device: torch.device,
|
|
||||||
is_pin_memory: bool,
|
|
||||||
is_pooling_model: bool,
|
|
||||||
custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
|
|
||||||
) -> LogitsProcessors:
|
|
||||||
if is_pooling_model:
|
|
||||||
if custom_logitsprocs:
|
|
||||||
raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
|
|
||||||
logger.debug("Skipping logits processor loading because pooling models"
|
|
||||||
" do not support logits processors.")
|
|
||||||
return LogitsProcessors()
|
|
||||||
custom_logitsprocs_classes = logits_processor._load_custom_logitsprocs(
|
|
||||||
custom_logitsprocs)
|
|
||||||
return LogitsProcessors(
|
|
||||||
ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
|
|
||||||
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
import torch
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
from vllm.v1.sample.logits_processor import MinPLogitsProcessor
|
|
||||||
|
|
||||||
|
|
||||||
class AscendMinPLogitsProcessor(MinPLogitsProcessor):
|
|
||||||
|
|
||||||
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
|
|
||||||
is_pin_memory: bool):
|
|
||||||
super().__init__(vllm_config, device, is_pin_memory)
|
|
||||||
|
|
||||||
decode_max_num_seqs = getattr(vllm_config.scheduler_config,
|
|
||||||
'decode_max_num_seqs', 0)
|
|
||||||
if decode_max_num_seqs != 0:
|
|
||||||
max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
|
|
||||||
decode_max_num_seqs)
|
|
||||||
|
|
||||||
self.min_p_count: int = 0
|
|
||||||
|
|
||||||
self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cpu",
|
|
||||||
pin_memory=is_pin_memory)
|
|
||||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
|
||||||
|
|
||||||
self.use_double_tensor = torch.device(device).type != "cpu"
|
|
||||||
|
|
||||||
if self.use_double_tensor:
|
|
||||||
# Pre-allocated device tensor
|
|
||||||
self.min_p_device: torch.Tensor = torch.empty(
|
|
||||||
(max_num_reqs, ), dtype=torch.float32, device=device)
|
|
||||||
else:
|
|
||||||
self.min_p_device = self.min_p_cpu_tensor
|
|
||||||
# Current slice of the device tensor
|
|
||||||
self.min_p: torch.Tensor = self.min_p_device[:0]
|
|
||||||
|
|
||||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
|
||||||
if not self.min_p_count:
|
|
||||||
return logits
|
|
||||||
# Convert logits to probability distribution
|
|
||||||
probability_values = torch.nn.functional.softmax(logits, dim=-1)
|
|
||||||
# Calculate maximum probabilities per sequence
|
|
||||||
max_probabilities = torch.amax(probability_values,
|
|
||||||
dim=-1,
|
|
||||||
keepdim=True)
|
|
||||||
# Adjust min_p
|
|
||||||
adjusted_min_p = max_probabilities.mul_(self.min_p)
|
|
||||||
# Identify valid tokens using threshold comparison
|
|
||||||
invalid_token_mask = probability_values < adjusted_min_p
|
|
||||||
# Apply mask using boolean indexing
|
|
||||||
logits.masked_fill_(invalid_token_mask, -float('inf'))
|
|
||||||
return logits
|
|
||||||
@@ -4,7 +4,8 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
from vllm.triton_utils import HAS_TRITON, triton
|
from vllm.triton_utils import HAS_TRITON, triton
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE,
|
from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE, MAX_SPEC_LEN,
|
||||||
|
PLACEHOLDER_TOKEN_ID,
|
||||||
generate_uniform_probs)
|
generate_uniform_probs)
|
||||||
|
|
||||||
from vllm_ascend.ops.triton.reject_sample import (
|
from vllm_ascend.ops.triton.reject_sample import (
|
||||||
@@ -13,11 +14,6 @@ from vllm_ascend.ops.triton.reject_sample import (
|
|||||||
sample_recovered_tokens_kernel)
|
sample_recovered_tokens_kernel)
|
||||||
from vllm_ascend.sample.sampler import apply_top_k_top_p
|
from vllm_ascend.sample.sampler import apply_top_k_top_p
|
||||||
|
|
||||||
PLACEHOLDER_TOKEN_ID = -1
|
|
||||||
# Maximum number of speculative draft tokens allowed per request in a single
|
|
||||||
# step. This value is chosen to be large enough to handle typical use cases.
|
|
||||||
MAX_SPEC_LEN = 32
|
|
||||||
|
|
||||||
|
|
||||||
def apply_sampling_constraints(
|
def apply_sampling_constraints(
|
||||||
logits: torch.Tensor, # [num_tokens, vocab_size]
|
logits: torch.Tensor, # [num_tokens, vocab_size]
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
|
|||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID
|
||||||
from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer
|
from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
@@ -40,8 +41,6 @@ from vllm_ascend.ops.triton.spec_decode.utils import \
|
|||||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||||
from vllm_ascend.utils import shared_expert_dp_enabled
|
from vllm_ascend.utils import shared_expert_dp_enabled
|
||||||
|
|
||||||
PADDING_SLOT_ID = -1
|
|
||||||
|
|
||||||
# Currently we will fix block size to a small one since `num_reqs` can't be too large
|
# Currently we will fix block size to a small one since `num_reqs` can't be too large
|
||||||
_PREPARE_INPUTS_BLOCK_SIZE = 4
|
_PREPARE_INPUTS_BLOCK_SIZE = 4
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
|||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID
|
||||||
|
|
||||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
@@ -18,8 +19,6 @@ from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
|||||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||||
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
|
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
|
||||||
|
|
||||||
PADDING_SLOT_ID = -1
|
|
||||||
|
|
||||||
|
|
||||||
class MtpProposer(EagleProposer):
|
class MtpProposer(EagleProposer):
|
||||||
|
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
|||||||
LogprobsLists, LogprobsTensors, ModelRunnerOutput,
|
LogprobsLists, LogprobsTensors, ModelRunnerOutput,
|
||||||
SamplerOutput,
|
SamplerOutput,
|
||||||
make_empty_encoder_model_runner_output)
|
make_empty_encoder_model_runner_output)
|
||||||
|
from vllm.v1.sample.logits_processor import build_logitsprocs
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
@@ -98,7 +99,6 @@ from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
|||||||
from vllm_ascend.eplb.utils import model_register
|
from vllm_ascend.eplb.utils import model_register
|
||||||
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
|
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
|
||||||
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
||||||
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
|
||||||
from vllm_ascend.sample.sampler import AscendSampler
|
from vllm_ascend.sample.sampler import AscendSampler
|
||||||
from vllm_ascend.spec_decode import get_spec_decode_method
|
from vllm_ascend.spec_decode import get_spec_decode_method
|
||||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||||
|
|||||||
Reference in New Issue
Block a user