1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
| import argparse import json import os import platform import sys import urllib.error import urllib.parse import urllib.request
def detect_shell(explicit_shell: str | None) -> str: if explicit_shell: return explicit_shell
system = platform.system().lower() if system == "windows": return "pwsh"
shell = os.environ.get("SHELL", "").strip() shell_name = os.path.basename(shell).lower() if shell_name in {"bash", "fish"}: return shell_name return "bash"
def normalize_base_url(base_url: str) -> str:
    """Return the full chat-completions endpoint URL for ``base_url``.

    ``base_url`` may be an API root (e.g. ``https://api.openai.com/v1``) or
    the complete endpoint. ``/chat/completions`` is appended unless the path
    already ends with it. Trailing slashes are ignored; any path prefix and
    the params, query and fragment are carried over unchanged.

    Raises:
        ValueError: if the URL lacks a scheme or host, or its scheme is not
            http/https.
    """
    parsed = urllib.parse.urlparse(base_url)
    if not parsed.scheme or not parsed.netloc:
        raise ValueError(
            "BASE_URL must include scheme and host, for example https://api.openai.com/v1"
        )
    # urllib.request.urlopen also honours file:, ftp: and data: URLs; only
    # HTTP(S) makes sense for an API endpoint, so refuse everything else
    # instead of letting the request read something that is not the API.
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"BASE_URL must use http or https, got {parsed.scheme!r}")

    path = parsed.path.rstrip("/")
    # The same suffix applies whatever the prefix is (/v1, /api/v3, none, ...).
    if not path.endswith("/chat/completions"):
        path = f"{path}/chat/completions"

    return urllib.parse.urlunparse(parsed._replace(path=path))
def build_messages(shell_name: str, task: str) -> list[dict[str, str]]:
    """Build the two-message chat payload: a system prompt and the user task.

    The system prompt names ``shell_name`` as the target shell and asks for a
    single bare command; ``task`` is passed through verbatim as the user turn.
    """
    system_prompt = (
        "You convert a user's intent into exactly one shell command. "
        f"The target shell is {shell_name}. "
        "Return only the command, with no markdown, no explanation, no surrounding quotes. "
        "Prefer common built-ins and standard tools available in that shell environment. "
        "Do not return dangerous commands like deleting system files or recursive destructive actions "
        "unless the user explicitly asked for them."
    )
    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": task}
    return [system_message, user_message]
def call_api(
    base_url: str,
    model: str,
    api_key: str,
    messages: list[dict[str, str]],
    timeout: float = 60.0,
) -> str:
    """POST ``messages`` to the chat-completions endpoint and return the reply text.

    Args:
        base_url: API root or full endpoint; resolved with ``normalize_base_url``.
        model: Model name sent in the request body.
        api_key: Sent as a ``Bearer`` token in the ``Authorization`` header.
        messages: Chat messages, e.g. as produced by ``build_messages``.
        timeout: Seconds to wait on the connection and on reads before
            giving up (without it a stalled server hangs the CLI forever).

    Returns:
        ``choices[0].message.content`` from the JSON response.

    Raises:
        RuntimeError: on HTTP errors, connection failures, timeouts, a body
            that is not UTF-8 JSON, or a response missing string content.
    """
    body = json.dumps(
        {
            "model": model,
            "messages": messages,
            "temperature": 0.1,
        }
    ).encode("utf-8")

    request = urllib.request.Request(
        normalize_base_url(base_url),
        data=body,
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        method="POST",
    )

    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            raw = response.read()
    except urllib.error.HTTPError as exc:
        details = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"API request failed: HTTP {exc.code}\n{details}") from exc
    except urllib.error.URLError as exc:
        raise RuntimeError(f"API connection failed: {exc.reason}") from exc
    except TimeoutError as exc:
        # A read timeout surfaces as TimeoutError rather than URLError.
        raise RuntimeError(f"API request timed out after {timeout} seconds") from exc

    try:
        payload = json.loads(raw.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        snippet = raw[:500].decode("utf-8", errors="replace")
        raise RuntimeError(f"API returned a non-JSON response: {snippet}") from exc

    unexpected = f"Unexpected API response format: {json.dumps(payload, ensure_ascii=False)}"
    try:
        content = payload["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise RuntimeError(unexpected) from exc
    # Some servers send "content": null (e.g. refusals or tool calls); reject
    # it here instead of letting callers fail on a None later.
    if not isinstance(content, str):
        raise RuntimeError(unexpected)
    return content
def clean_command(text: str) -> str:
    """Strip Markdown wrapping that a model may put around the command.

    Handles:
    * a fenced block (```` ```lang ... ``` ````). The opening fence line,
      with its optional language tag, is dropped. Only lines up to the
      closing fence are kept; prose after the fence is discarded, and an
      unclosed fence keeps everything after the opening line.
    * a single-line fence such as ```` ```ls -la``` ````.
    * a lone inline-code span such as `` `ls -la` ``.

    Backticks that belong to the command itself (e.g. bash ``echo `date` ``)
    are preserved. Surrounding whitespace is always stripped.
    """
    result = text.strip()
    if result.startswith("```"):
        opening, newline, rest = result.partition("\n")
        if not newline:
            # Everything on one line, e.g. ```ls -la```.
            return opening.strip("`").strip()
        body: list[str] = []
        for line in rest.splitlines():
            trimmed = line.rstrip()
            if trimmed.lstrip().startswith("```"):
                break  # closing fence on its own line
            if trimmed.endswith("```"):
                body.append(trimmed[:-3])  # closing fence glued to the last line
                break
            body.append(line)
        return "\n".join(body).strip()

    # Unwrap `...` only when that outer pair is the sole pair of backticks, so
    # commands that use backticks themselves survive untouched.
    if len(result) >= 2 and result[0] == result[-1] == "`" and "`" not in result[1:-1]:
        return result[1:-1].strip()
    return result
def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Turn a natural language request into one shell command for the current shell." ) parser.add_argument( "task", nargs="*", help="Task description for the model. Reads from stdin if omitted.", ) parser.add_argument( "--base-url", required=True, help="BASE_URL for an OpenAI-compatible API." ) parser.add_argument("--model", required=True, help="Model name to call.") parser.add_argument("--api-key", required=True, help="API key.") parser.add_argument( "--shell", choices=["pwsh", "bash", "fish"], help="Override the target shell. Defaults to auto-detection.", ) return parser.parse_args()
def read_task(args: argparse.Namespace) -> str:
    """Return the task text, from positional arguments or else piped stdin.

    Positional words are joined with single spaces. When there are none and
    stdin is not a terminal, the whole of stdin is read. Either way the result
    is stripped of surrounding whitespace.

    Raises:
        ValueError: if there are no positional words and stdin is a terminal.
    """
    words = args.task
    if words:
        return " ".join(words).strip()
    if sys.stdin.isatty():
        raise ValueError("Provide a task as arguments or via stdin.")
    return sys.stdin.read().strip()
def main() -> int:
    """Run the CLI: read the task, query the model, print one command.

    Returns:
        The process exit status:
        * 0 after printing a non-empty command;
        * 1 on any error, whose message goes to stderr;
        * 130 when interrupted with Ctrl+C.
    """
    args = parse_args()

    try:
        task = read_task(args)
        if not task:
            raise ValueError("Task text must not be empty.")

        shell_name = detect_shell(args.shell)
        command = call_api(
            base_url=args.base_url,
            model=args.model,
            api_key=args.api_key,
            messages=build_messages(shell_name, task),
        )
        cleaned = clean_command(command)
        if not cleaned:
            # A blank line with exit status 0 would make a failed generation
            # look like success to whatever invoked this tool.
            raise RuntimeError("The model returned an empty command.")
        print(cleaned)
        return 0
    except KeyboardInterrupt:
        # 128 + SIGINT: the conventional status for Ctrl+C, without a traceback.
        return 130
    except Exception as exc:  # top-level boundary: report and exit non-zero
        print(str(exc), file=sys.stderr)
        return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|