mirror of
https://github.com/handsomezhuzhu/vllm-npu-plugin.git
synced 2026-02-20 19:50:15 +00:00
大改
This commit is contained in:
50
vllm_npu/sample/logits_processor/__init__.py
Normal file
50
vllm_npu/sample/logits_processor/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import itertools
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.sample import logits_processor
|
||||
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
|
||||
MinTokensLogitsProcessor)
|
||||
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
||||
from vllm.v1.sample.logits_processor.state import LogitsProcessors
|
||||
|
||||
from vllm_npu.sample.logits_processor.builtin import \
|
||||
AscendMinPLogitsProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
# Module-level logger created through vLLM's init_logger helper.
logger = init_logger(__name__)
|
||||
|
||||
# Error message raised when the user tries to initialize vLLM with a pooling
# model together with custom logits processors (logitsprocs).
STR_POOLING_REJECTS_LOGITSPROCS = (
    "Pooling models do not support custom logits processors.")
|
||||
|
||||
# Built-in logits processor classes. build_logitsprocs() instantiates them in
# this order, ahead of any custom classes. AscendMinPLogitsProcessor is this
# plugin's own class (vllm_npu.sample.logits_processor.builtin); the other two
# are imported from vLLM's builtin module.
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
    MinTokensLogitsProcessor,
    LogitBiasLogitsProcessor,
    AscendMinPLogitsProcessor,
]
|
||||
|
||||
|
||||
def build_logitsprocs(
    vllm_config: "VllmConfig",
    device: torch.device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
) -> LogitsProcessors:
    """Instantiate the built-in and custom logits processors.

    Args:
        vllm_config: Forwarded as the first constructor argument of every
            logits processor class.
        device: Forwarded as the second constructor argument.
        is_pin_memory: Forwarded as the third constructor argument.
        is_pooling_model: When true, no processors are built at all.
        custom_logitsprocs: Custom processors, given either as strings or
            as ``LogitsProcessor`` subclasses; they are resolved to classes
            by vLLM's ``_load_custom_logitsprocs``.

    Returns:
        A ``LogitsProcessors`` built from one instance of each class in
        ``BUILTIN_LOGITS_PROCESSORS`` followed by one instance of each
        custom class, or a ``LogitsProcessors`` constructed with no
        arguments for pooling models.

    Raises:
        ValueError: If ``is_pooling_model`` is true and
            ``custom_logitsprocs`` is non-empty.
    """
    # Pooling models reject custom processors outright.
    if is_pooling_model and custom_logitsprocs:
        raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)

    if is_pooling_model:
        logger.debug("Skipping logits processor loading because pooling models"
                     " do not support logits processors.")
        return LogitsProcessors()

    loaded_custom_classes = logits_processor._load_custom_logitsprocs(
        custom_logitsprocs)
    # Built-ins first, then the custom classes, in their given order.
    processor_classes = itertools.chain(BUILTIN_LOGITS_PROCESSORS,
                                        loaded_custom_classes)
    return LogitsProcessors(
        processor_cls(vllm_config, device, is_pin_memory)
        for processor_cls in processor_classes)
|
||||
35
vllm_npu/sample/logits_processor/builtin.py
Normal file
35
vllm_npu/sample/logits_processor/builtin.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.sample.logits_processor import MinPLogitsProcessor
|
||||
|
||||
|
||||
class AscendMinPLogitsProcessor(MinPLogitsProcessor):
    """Min-p logits processor that honours ``decode_max_num_seqs``.

    After the parent constructor runs, the min-p buffers are (re)created
    with ``max(max_num_seqs, decode_max_num_seqs)`` entries whenever the
    scheduler config carries a non-zero ``decode_max_num_seqs``.
    NOTE(review): presumably this replaces smaller buffers allocated by
    ``MinPLogitsProcessor.__init__`` -- confirm against the parent class.
    """

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool):
        super().__init__(vllm_config, device, is_pin_memory)

        scheduler_config = vllm_config.scheduler_config
        decode_max_num_seqs = getattr(scheduler_config,
                                      'decode_max_num_seqs', 0)
        # Attribute absent or zero: keep whatever the parent set up.
        if decode_max_num_seqs == 0:
            return

        capacity = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
        buffer_shape = (capacity, )

        self.min_p_count: int = 0

        # Host-side buffer, plus a NumPy view obtained from the same tensor.
        self.min_p_cpu_tensor = torch.zeros(buffer_shape,
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=is_pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()

        # A separate device tensor is only needed off-CPU; on CPU the host
        # tensor doubles as the "device" tensor.
        target_is_cpu = torch.device(device).type == "cpu"
        self.use_double_tensor = not target_is_cpu
        self.min_p_device: torch.Tensor = (
            self.min_p_cpu_tensor if target_is_cpu else torch.empty(
                buffer_shape, dtype=torch.float32, device=device))

        # Current slice of the device tensor (starts out empty).
        self.min_p: torch.Tensor = self.min_p_device[:0]
|
||||
Reference in New Issue
Block a user