diff --git a/vllm_npu/worker/worker_v1.py b/vllm_npu/worker/worker_v1.py
index 5e33b0d..7cb7bfe 100644
--- a/vllm_npu/worker/worker_v1.py
+++ b/vllm_npu/worker/worker_v1.py
@@ -13,7 +13,10 @@ from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 from vllm.config import VllmConfig
-from vllm.distributed import init_distributed_environment
+from vllm.distributed import (
+    ensure_model_parallel_initialized,
+    init_distributed_environment,
+)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform
@@ -92,6 +95,14 @@ class NPUWorker(WorkerBase):
             backend="hccl",
         )
 
+        # Initialize TP / PP parallel groups
+        ensure_model_parallel_initialized(
+            tensor_model_parallel_size=(
+                self.parallel_config.tensor_parallel_size),
+            pipeline_model_parallel_size=(
+                self.parallel_config.pipeline_parallel_size),
+        )
+
         # Set random seed
         current_platform.seed_everything(self.model_config.seed)
 