Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 11:42:30 +00:00)
feat: Add vLLM NPU offline inference demo script.
demo.py (new file, 69 lines)
@@ -0,0 +1,69 @@
"""
Quick offline inference demo — no server needed.

Usage:
    python demo.py
    python demo.py --model /path/to/model
    python demo.py --prompt "What is AI?"
"""

import argparse

from vllm import LLM, SamplingParams


def main():
    parser = argparse.ArgumentParser(description="vLLM NPU offline demo")
    parser.add_argument(
        "--model",
        default="/workspace/mnt/vllm_ascend/Qwen2.5-7B-Instruct",
        help="Path to the model",
    )
    parser.add_argument(
        "--prompt",
        default="你好,请简单介绍一下自己",  # "Hello, please briefly introduce yourself"
        help="User prompt",
    )
    parser.add_argument("--max-tokens", type=int, default=128)
    parser.add_argument("--max-model-len", type=int, default=512)
    parser.add_argument("--dtype", default="float16")
    parser.add_argument("--block-size", type=int, default=128)
    args = parser.parse_args()

    print(f"Loading model: {args.model}")
    llm = LLM(
        model=args.model,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        block_size=args.block_size,
        trust_remote_code=True,
        enforce_eager=True,  # skip graph capture for debugging
    )

    # Build chat-format prompt for Qwen2.5
    messages = [{"role": "user", "content": args.prompt}]
    tokenizer = llm.get_tokenizer()
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=args.max_tokens,
    )

    print(f"\nPrompt: {args.prompt}")
    print("-" * 60)

    outputs = llm.generate([prompt], sampling_params)

    for output in outputs:
        generated = output.outputs[0].text
        print(f"Response:\n{generated}")
        print("-" * 60)
        print(f"Tokens generated: {len(output.outputs[0].token_ids)}")


if __name__ == "__main__":
    main()
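For reference, a minimal sketch of extending the demo to batch several prompts: vLLM schedules everything passed to a single generate() call together, so no loop over generate() is needed. This is not part of the commit; it assumes the llm, tokenizer, and sampling_params objects built in main() above, and the example questions are placeholders.

    # Hedged sketch: batch multiple chat-formatted prompts in one call.
    questions = ["What is AI?", "Explain KV caching in one sentence."]
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": q}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for q in questions
    ]
    # One generate() call; vLLM batches the requests internally.
    for out in llm.generate(prompts, sampling_params):
        print(out.outputs[0].text.strip())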