diff --git a/run_inference_server.py b/run_inference_server.py
index 9b0f10d53..5ae343998 100644
--- a/run_inference_server.py
+++ b/run_inference_server.py
@@ -4,16 +4,55 @@
 import platform
 import argparse
 import subprocess
+import re
+import ipaddress
 
-def run_command(command, shell=False):
-    """Run a system command and ensure it succeeds."""
+def validate_path(path):
+    """Return the normalized path; raise ValueError if it contains shell metacharacters."""
+    # normpath only collapses separators and '.'/'..' segments; it does not confine the path
+    normalized = os.path.normpath(path)
+    # Reject shell metacharacters and line breaks as defense in depth
+    suspicious_patterns = [';', '&', '|', '$', '`', '\n', '\r', '>', '<', '(', ')']
+    if any(char in normalized for char in suspicious_patterns):
+        raise ValueError(f"Invalid characters detected in path: {path}")
+    return normalized
+
+def validate_ip_address(ip):
+    """Return ip unchanged if it parses as an IP address; raise ValueError otherwise."""
+    try:
+        ipaddress.ip_address(ip)
+    except ValueError as err:
+        raise ValueError(f"Invalid IP address: {ip}") from err
+    return ip
+
+def validate_prompt(prompt):
+    """Return prompt unchanged; raise ValueError if it contains shell metacharacters."""
+    # Substrings rejected as possible command-injection attempts
+    suspicious_patterns = ['$(', '`', '|', ';', '&', '\n', '\r']
+    if any(pattern in prompt for pattern in suspicious_patterns):
+        raise ValueError("Invalid characters detected in prompt")
+    return prompt
+
+def run_command(command):
+    """Run a command with shell=False; print the error and exit(1) if it exits non-zero."""
     try:
-        subprocess.run(command, shell=shell, check=True)
+        # Force shell=False to prevent command injection
+        subprocess.run(command, shell=False, check=True)
     except subprocess.CalledProcessError as e:
         print(f"Error occurred while running command: {e}")
         sys.exit(1)
 
 def run_server():
+    # Validate all user inputs before using them
+    try:
+        validated_model = validate_path(args.model)
+        validated_host = validate_ip_address(args.host)
+        # Bound on every path; None means no prompt was supplied
+        validated_prompt = validate_prompt(args.prompt) if args.prompt else None
+    except ValueError as e:
+        print(f"Validation error: {e}")
+        sys.exit(1)
+
     build_dir = "build"
     if platform.system() == "Windows":
         server_path = os.path.join(build_dir, "bin", "Release", "llama-server.exe")
@@ -24,23 +63,23 @@ def run_server():
 
     command = [
         f'{server_path}',
-        '-m', args.model,
+        '-m', validated_model,
         '-c', str(args.ctx_size),
         '-t', str(args.threads),
         '-n', str(args.n_predict),
         '-ngl', '0',
         '--temp', str(args.temperature),
-        '--host', args.host,
+        '--host', validated_host,
         '--port', str(args.port),
         '-cb'  # Enable continuous batching
     ]
 
     if args.prompt:
-        command.extend(['-p', args.prompt])
+        command.extend(['-p', validated_prompt])
 
     # Note: -cnv flag is removed as it's not supported by the server
 
-    print(f"Starting server on {args.host}:{args.port}")
+    print(f"Starting server on {validated_host}:{args.port}")
     run_command(command)
 
 def signal_handler(sig, frame):