init
This commit is contained in:
63
worker/multi_step_neuronx_distributed_model_runner.py
Normal file
63
worker/multi_step_neuronx_distributed_model_runner.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.worker.neuronx_distributed_model_runner import (
|
||||
NeuronxDistributedModelRunner)
|
||||
|
||||
|
||||
class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
|
||||
"""A model runner for multi-step decoding using the
|
||||
neuronx-distributed-inference framework"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
):
|
||||
super().__init__(vllm_config)
|
||||
|
||||
def load_model(self) -> None:
|
||||
from vllm.model_executor.model_loader.neuronx_distributed import (
|
||||
get_neuron_speculation_model)
|
||||
self.model = get_neuron_speculation_model(
|
||||
self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
speculation_config=self.speculative_config)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
sampling_params = torch.tensor([[
|
||||
seq_group.sampling_params.top_k,
|
||||
seq_group.sampling_params.top_p,
|
||||
seq_group.sampling_params.temperature,
|
||||
] for seq_group in model_input.sampling_metadata.seq_groups])
|
||||
|
||||
logits = self.model(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
sampling_params=sampling_params,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
return output
|
||||
Reference in New Issue
Block a user