diff --git a/go_cli.py b/go_cli.py index 42296d4..d33355f 100644 --- a/go_cli.py +++ b/go_cli.py @@ -19,8 +19,9 @@ import fcntl import platform import requests +import click +import json import subprocess -import argparse import termios import urllib.parse import sys @@ -29,8 +30,7 @@ __version__ = "0.0.11" # current version SERVER_URL = "https://cli.gorilla-llm.com" -UPDATE_CHECK_FILE = os.path.expanduser("~/.gorilla-cli-last-update-check") -USERID_FILE = os.path.expanduser("~/.gorilla-cli-userid") +CONFIG_FILE = os.path.expanduser("~/.gorilla-cli-config.json") HISTORY_FILE = os.path.expanduser("~/.gorilla_cli_history") ISSUE_URL = f"https://github.com/gorilla-llm/gorilla-cli/issues/new" GORILLA_EMOJI = "🦍 " if go_questionary.try_encode_gorilla() else "" @@ -60,10 +60,6 @@ def get_git_email(): def get_system_info(): return platform.system() -def write_uid_to_file(uid): - with open(USERID_FILE, "w") as f: - f.write(uid) - def append_to_bash_history(selected_command): try: with open(os.path.expanduser("~/.bash_history"), "a") as history_file: @@ -103,10 +99,16 @@ def raise_issue(title, body): def check_for_updates(): # Check if a new version of gorilla-cli is available once a day try: - with open(UPDATE_CHECK_FILE, "r") as f: - last_check_date = datetime.datetime.strptime(f.read(), "%Y-%m-%d") - except FileNotFoundError: + with open(CONFIG_FILE, "r") as config_file: + config_json = json.load(config_file) + except Exception as e: + config_json = {} + + if "last_check_date" in config_json: + last_check_date = datetime.datetime.strptime(config_json["last_check_date"], "%Y-%m-%d") + else: last_check_date = datetime.datetime.now() - datetime.timedelta(days=1) + if datetime.datetime.now() - last_check_date >= datetime.timedelta(days=1): try: response = requests.get("https://pypi.org/pypi/gorilla-cli/json") @@ -116,11 +118,10 @@ def check_for_updates(): print(f"A new version is available: {latest_version}. Update with `pip install --upgrade gorilla-cli`") except Exception as e: print("Unable to check for updates:", e) - try: - with open(UPDATE_CHECK_FILE, "w") as f: - f.write(datetime.datetime.now().strftime("%Y-%m-%d")) - except Exception as e: - print("Unable to write update check file:", e) + + config_json["last_check_date"] = datetime.datetime.now().strftime("%Y-%m-%d") + with open(CONFIG_FILE, "w") as config_file: + json.dump(config_json, config_file) def get_user_id(): @@ -130,33 +131,59 @@ def get_user_id(): # for commercial serving. If you would like to request rate # limit increases for your GitHub handle, please raise an issue. try: - with open(USERID_FILE, "r") as f: - user_id = f.read().strip() - if not user_id: - user_id = generate_random_uid() - return user_id - except FileNotFoundError: + with open(CONFIG_FILE, "r") as config_file: + config_json = json.load(config_file) + except Exception as e: + config_json = {} + + if "user_id" in config_json: + return config_json["user_id"] + + user_id = get_legacy_user_id() + if user_id: + config_json["user_id"] = user_id + else: try: user_id = get_git_email() - print(WELCOME_TEXT) - response = input(f"Use your Github handle ({user_id}) as user id? [Y/n]: ").strip().lower() + response = ( + input(f"Use your Github handle ({user_id}) as user id? [Y/n]: ") + .strip() + .lower() + ) if response in ["n", "no"]: user_id = generate_random_uid() + + print(WELCOME_TEXT) + config_json["user_id"] = user_id + except Exception as e: + # If git not installed then generate and use a random user id print(f"Unable to import userid from Git. Git not installed or git user.email not configured.") print(f"Will use a random user-id. \n") user_id = generate_random_uid() - print(WELCOME_TEXT) - try: - write_uid_to_file(user_id) - return user_id - except Exception as e: - print(f"Unable to write userid to file: {e}") - raise_issue("Problem with userid file", f"Unable to write userid file: {e}") - print(f"Using a temporary UID {user_id} for now.") - return user_id + try: + with open(CONFIG_FILE, "w") as config_file: + config_json = json.dump(config_json, config_file) + except Exception as e: + print(f"Unable to write userid to config file: {e}") + raise_issue("Problem with cofnig file", f"Unable to write userid to cofnig file: {e}") + print(f"Using a temporary UID {user_id} for now.") + + return user_id +def get_legacy_user_id(): + # Previously, user id were not stored in config.json and were stored in a separate file. + # This function checks whether or not the user already have their user_id set + # Before the update to use config.json. + USERID_FILE = os.path.expanduser("~/.gorilla-cli-userid") + try: + with open(USERID_FILE, "r") as f: + user_id = f.read().strip() + return user_id + except: + return None + def format_command(input_str): """ Standardize commands to be stored with a newline @@ -184,76 +211,164 @@ def append_string_to_file_if_missing(file_path, target_string): file.write(target_string) -def main(): - def execute_command(cmd): - cmd = format_command(cmd) - process = subprocess.run(cmd, shell=True, stderr=subprocess.PIPE) - - save = not cmd.startswith(':') - if save: - append_string_to_file_if_missing(HISTORY_FILE, cmd) - - error_msg = process.stderr.decode("utf-8", "ignore") - if error_msg: - print(f"{error_msg}") - return error_msg - return str(process.returncode) - - def get_history_commands(history_file): - """ - Takes in history file - Returns None if file doesn't exist or empty - Returns list of last 10 history commands in the file if it exists - """ - if os.path.isfile(history_file): - with open(history_file, 'r') as history: - lines = history.readlines() - if not lines: - print("No command history.") - return lines[-HISTORY_LENGTH:] - else: - print("No command history.") - return - - args = sys.argv[1:] - user_input = " ".join(args) - user_id = get_user_id() - system_info = get_system_info() +def specify_models(ctx, param, file_path): + # By default, Gorilla-CLI combines the capabilities of multiple Language Learning Models. + # The specify_models command will make Gorilla exclusively utilize the inputted models. + if not file_path or ctx.resilient_parsing: + return + try: + with open(file_path, "r") as models_file: + models_json = json.load(models_file) + except Exception as e: + print("Failed to read from " + file_path) + ctx.exit() + try: + with open(CONFIG_FILE, "r") as config_file: + config_json = json.load(config_file) + except Exception as e: + config_json = {} + config_json["models"] = models_json["models"] + with open(CONFIG_FILE, "w") as config_file: + json.dump(config_json, config_file) + print("models set to: " + str(config_json["models"])) + ctx.exit() - # Parse command-line arguments - parser = argparse.ArgumentParser(description="Gorilla CLI Help Doc") - parser.add_argument("-p", "--history", action="store_true", help="Display command history") - parser.add_argument("command_args", nargs='*', help="Prompt to be inputted to Gorilla") - args = parser.parse_args() +def reset_models(ctx, param, value): + if not value or ctx.resilient_parsing: + return + try: + with open(CONFIG_FILE, "r") as config_file: + config_json = json.load(config_file) + if "models" in config_json: + del config_json["models"] + with open(CONFIG_FILE, "w") as config_file: + json.dump(config_json, config_file) + except Exception as e: + pass + ctx.exit() + + +def execute_command(cmd): + cmd = format_command(cmd) + process = subprocess.run(cmd, shell=True, stderr=subprocess.PIPE) + + save = not cmd.startswith(':') + if save: + append_string_to_file_if_missing(HISTORY_FILE, cmd) + + error_msg = process.stderr.decode("utf-8", "ignore") + if error_msg: + print(f"{error_msg}") + return error_msg + return str(process.returncode) + +def load_config(): + # Load the user's configuration file and perform any necessary checks + if os.path.isfile(CONFIG_FILE): + with open(CONFIG_FILE, "r") as config_file: + config_json = json.load(config_file) + return config_json + + +def print_version(ctx, param, value): + if not value or ctx.resilient_parsing: + return + click.echo(__version__) + ctx.exit() + +def get_history_commands(ctx, param, value): + """ + Takes in history file + Returns None if file doesn't exist or empty + Returns list of last 10 history commands in the file if it exists + """ + if not value or ctx.resilient_parsing: + return + history_file = HISTORY_FILE + if os.path.isfile(history_file): + with open(history_file, 'r') as history: + lines = history.readlines() + if not lines: + click.echo("No command history.") + else: + click.echo(lines[-HISTORY_LENGTH:]) + else: + click.echo("No command history.") + ctx.exit() + +def format_output_commands(commands): + for i, command in enumerate(commands): + if command[-1] == '\n': + commands[i] = command[:-1] + break + return commands + + + +@click.command() +@click.option('--user_id', '--u', default=get_user_id(), help="User id [default: 'git config --global user.email' OR random uuid]") +@click.option('--set_models', type=click.Path(), callback = specify_models, expose_value=True, + help = "Make Gorilla exclusively utilize the models in the json file specified") +@click.option('--reset_models', is_flag=True, callback =reset_models, expose_value=False, is_eager=True, + help = "Reset models configuration") +@click.option('--model', '-m', help = "Prompt Gorilla CLI to only use the specified model") +@click.option('--version', help = "Return the version of GORILLA_CLI", is_flag=True, callback= print_version, expose_value=False, is_eager=True) +@click.option('--history', help = "Display command history", is_flag=True, callback= get_history_commands, expose_value=False, is_eager=True) +@click.argument('prompt', nargs = -1) +def main( + user_id, + model, + set_models, + prompt, +): + check_for_updates() + config = load_config() + + if len(prompt) == 0: + print("error: prompt not found, see gorilla-cli usage below " + "➡️") + with click.Context(main) as ctx: + click.echo(main.get_help(ctx)) + return + + #Check if the user has specific model preference. + if model: + chosen_models = model + elif "models" in config: + chosen_models = config["models"] + else: + chosen_models = None # Generate a unique interaction ID interaction_id = str(uuid.uuid4()) - if args.history: - commands = get_history_commands(HISTORY_FILE) - else: - with Halo(text=f"{GORILLA_EMOJI}Loading", spinner="dots"): - try: - data_json = { - "user_id": user_id, - "user_input": user_input, - "interaction_id": interaction_id, - "system_info": system_info - } - response = requests.post( - f"{SERVER_URL}/commands_v2", json=data_json, timeout=30 - ) - commands = response.json() - except requests.exceptions.RequestException as e: - print("Server is unreachable.") - print("Try updating Gorilla-CLI with 'pip install --upgrade gorilla-cli'") - return - - check_for_updates() + args = sys.argv[1:] + user_input = " ".join(args) + system_info = get_system_info() + data_json = { + "user_id": user_id, + "user_input": user_input, + "interaction_id": interaction_id, + "system_info": system_info + } + if chosen_models: + data_json["models"] = chosen_models + print("Results are only chosen from the following LLM model(s): ", chosen_models) + + with Halo(text=f"{GORILLA_EMOJI}Loading", spinner="dots"): + try: + response = requests.post( + f"{SERVER_URL}/commands", json=data_json, timeout=30 + ) + commands = response.json() + except requests.exceptions.RequestException as e: + print("\nServer " + SERVER_URL + " is unreachable.") + print("Try updating Gorilla-CLI with 'pip install --upgrade gorilla-cli'") + return if commands: + commands = format_output_commands(commands) selected_command = go_questionary.select( "", choices=commands, instruction="Welcome to Gorilla. Use arrow keys to select. Ctrl-C to Exit" ).ask() @@ -261,13 +376,14 @@ def get_history_commands(history_file): if not selected_command: # happens when Ctrl-C is pressed return - exit_condition = execute_command(selected_command) # Append command to bash history if system_info == "Linux": append_to_bash_history(selected_command) prefill_shell_cmd(selected_command) + exit_condition = execute_command(selected_command) + # Commands failed / succeeded? try: response = requests.post( diff --git a/setup.py b/setup.py index fc81ba6..ebddd4b 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "requests", "halo", "prompt-toolkit", + "click" ], entry_points={ "console_scripts": [