diff --git a/vllm_npu/worker/worker_v1.py b/vllm_npu/worker/worker_v1.py index 7cb7bfe..772e524 100644 --- a/vllm_npu/worker/worker_v1.py +++ b/vllm_npu/worker/worker_v1.py @@ -184,6 +184,12 @@ class NPUWorker(WorkerBase): def get_kv_cache_spec(self) -> KVCacheSpec: return self.model_runner.get_kv_cache_spec() + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Store the number of KV cache blocks.""" + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate KV caches on NPU.""" self.model_runner.initialize_kv_cache(kv_cache_config)