"""
|
|
Quick offline inference demo — no server needed.
|
|
|
|
Usage:
|
|
python demo.py
|
|
python demo.py --model /path/to/model
|
|
python demo.py --prompt "What is AI?"
|
|
"""
|
|
|
|
import argparse
|
|
|
|
from vllm import LLM, SamplingParams
|
|
|
|
|
|
def main():
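    # Defaults point at a local Qwen2.5-7B-Instruct checkpoint; pass --model to
    # use any other locally downloaded Hugging Face model directory.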
    parser = argparse.ArgumentParser(description="vLLM NPU offline demo")
    parser.add_argument(
        "--model",
        default="/workspace/mnt/vllm_ascend/Qwen2.5-7B-Instruct",
        help="Path to the model",
    )
    parser.add_argument(
        "--prompt",
        default="你好,请简单介绍一下自己",  # "Hello, please briefly introduce yourself."
        help="User prompt",
    )
    parser.add_argument("--max-tokens", type=int, default=128)
    parser.add_argument("--max-model-len", type=int, default=512)
    parser.add_argument("--dtype", default="float16")
    parser.add_argument("--block-size", type=int, default=128)
    args = parser.parse_args()

    print(f"Loading model: {args.model}")
    llm = LLM(
        model=args.model,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        block_size=args.block_size,
        trust_remote_code=True,
        enforce_eager=True,  # skip graph capture for debugging
    )

    # Build chat-format prompt for Qwen2.5
    messages = [{"role": "user", "content": args.prompt}]
    tokenizer = llm.get_tokenizer()
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
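
    # Moderate sampling defaults; lower the temperature (or set it to 0) for
    # more deterministic answers.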
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=args.max_tokens,
    )

    print(f"\nPrompt: {args.prompt}")
    print("-" * 60)
    outputs = llm.generate([prompt], sampling_params)

    for output in outputs:
        generated = output.outputs[0].text
        print(f"Response:\n{generated}")
        print("-" * 60)
        print(f"Tokens generated: {len(output.outputs[0].token_ids)}")


if __name__ == "__main__":
    main()