diff --git a/fire/completion.py b/fire/completion.py index 8b74776f..e766a89f 100644 --- a/fire/completion.py +++ b/fire/completion.py @@ -46,13 +46,11 @@ def _BashScript(name, commands, default_options=None): A string which is the Bash script. Source the bash script to enable tab completion in Bash. """ + default_options = default_options or set() - options_map = collections.defaultdict(lambda: copy.copy(default_options)) - for command in commands: - start = (name + ' ' + ' '.join(command[:-1])).strip() - completion = _FormatForCommand(command[-1]) - options_map[start].add(completion) - options_map[start.replace('_', '-')].add(completion) + global_options, options_map, subcommands_map = _GetMaps( + name, commands, default_options + ) bash_completion_template = """# bash completion support for {name} # DO NOT EDIT. @@ -60,41 +58,128 @@ def _BashScript(name, commands, default_options=None): _complete-{identifier}() {{ - local start cur opts + local cur prev opts lastcommand COMPREPLY=() - start="${{COMP_WORDS[@]:0:COMP_CWORD}}" + prev="${{COMP_WORDS[COMP_CWORD-1]}}" cur="${{COMP_WORDS[COMP_CWORD]}}" + lastcommand=$(get_lastcommand) opts="{default_options}" + GLOBAL_OPTIONS="{global_options}" -{start_checks} +{checks} COMPREPLY=( $(compgen -W "${{opts}}" -- ${{cur}}) ) return 0 }} +get_lastcommand() +{{ + local lastcommand i + + lastcommand= + for ((i=0; i < ${{#COMP_WORDS[@]}}; ++i)); do + if [[ ${{COMP_WORDS[i]}} != -* ]] && [[ -n ${{COMP_WORDS[i]}} ]] && [[ + ${{COMP_WORDS[i]}} != $cur ]]; then + lastcommand=${{COMP_WORDS[i]}} + fi + done + + echo $lastcommand +}} + +filter_options() +{{ + local opts + opts="" + for opt in "$@" + do + if ! option_already_entered $opt; then + opts="$opts $opt" + fi + done + + echo $opts +}} + +option_already_entered() +{{ + local opt + for opt in ${{COMP_WORDS[@]:0:COMP_CWORD}} + do + if [ $1 == $opt ]; then + return 0 + fi + done + return 1 +}} + +is_prev_global() +{{ + local opt + for opt in $GLOBAL_OPTIONS + do + if [ $opt == $prev ]; then + return 0 + fi + done + return 1 +}} + complete -F _complete-{identifier} {command} """ - start_check_template = """ - if [[ "$start" == "{start}" ]] ; then - opts="{completions}" - fi""" - - start_checks = '\n'.join( - start_check_template.format( - start=start, - completions=' '.join(sorted(options_map[start])) + + check_wrapper = """ + case "${{lastcommand}}" in + {lastcommand_checks} + esac""" + + lastcommand_check_template = """ + {command}) + {opts_assignment} + opts=$(filter_options $opts) + ;;""" + + opts_assignment_subcommand_template = """ + if is_prev_global; then + opts="${{GLOBAL_OPTIONS}}" + else + opts="{options} ${{GLOBAL_OPTIONS}}" + fi""" + + opts_assignment_main_command_template = """ + opts="{options} ${{GLOBAL_OPTIONS}}" """ + + def _GetOptsAssignmentTemplate(command): + if command == name: + return opts_assignment_main_command_template + else: + return opts_assignment_subcommand_template + + lastcommand_checks = '\n'.join( + lastcommand_check_template.format( + command=command, + opts_assignment=_GetOptsAssignmentTemplate(command).format( + options=' '.join(sorted( + options_map[command].union(subcommands_map[command]) + )), + ), ) - for start in options_map + for command in set(subcommands_map.keys()).union(set(options_map.keys())) + ) + + checks = check_wrapper.format( + lastcommand_checks=lastcommand_checks, ) return ( bash_completion_template.format( name=name, command=name, - start_checks=start_checks, + checks=checks, default_options=' '.join(default_options), - identifier=name.replace('/', '').replace('.', '').replace(',', '') + identifier=name.replace('/', '').replace('.', '').replace(',', ''), + global_options=' '.join(global_options), ) ) @@ -114,44 +199,85 @@ def _FishScript(name, commands, default_options=None): completion in Fish. """ default_options = default_options or set() - options_map = collections.defaultdict(lambda: copy.copy(default_options)) - for command in commands: - start = (name + ' ' + ' '.join(command[:-1])).strip() - completion = _FormatForCommand(command[-1]) - options_map[start].add(completion) - options_map[start.replace('_', '-')].add(completion) + global_options, options_map, subcommands_map = _GetMaps( + name, commands, default_options + ) + fish_source = """function __fish_using_command set cmd (commandline -opc) - if [ (count $cmd) -eq (count $argv) ] - for i in (seq (count $argv)) - if [ $cmd[$i] != $argv[$i] ] + for i in (seq (count $cmd) 1) + switch $cmd[$i] + case "-*" + case "*" + if [ $cmd[$i] = $argv[1] ] + return 0 + else return 1 end end - return 0 end return 1 end + +function __option_entered_check + set cmd (commandline -opc) + for i in (seq (count $cmd)) + switch $cmd[$i] + case "-*" + if [ $cmd[$i] = $argv[1] ] + return 1 + end + end + end + return 0 +end + +function __is_prev_global + set cmd (commandline -opc) + set global_options {global_options} + set prev (count $cmd) + + for opt in $global_options + if [ "--$opt" = $cmd[$prev] ] + echo $prev + return 0 + end + end + return 1 +end + """ - subcommand_template = ("complete -c {name} -n '__fish_using_command {start}' " - "-f -a {subcommand}\n") + + subcommand_template = ("complete -c {name} -n '__fish_using_command " + "{command}' -f -a {subcommand}\n") flag_template = ("complete -c {name} -n " - "'__fish_using_command {start}' -l {option}\n") - for start in options_map: - for option in sorted(options_map[start]): - if option.startswith('--'): - fish_source += flag_template.format( - name=name, - start=start, - option=option[2:] - ) - else: - fish_source += subcommand_template.format( - name=name, - start=start, - subcommand=option - ) - return fish_source + "'__fish_using_command {command};{prev_global_check} and " + "__option_entered_check --{option}' -l {option}\n") + + prev_global_check = " and __is_prev_global;" + for command in set(subcommands_map.keys()).union(set(options_map.keys())): + for subcommand in subcommands_map[command]: + fish_source += subcommand_template.format( + name=name, + command=command, + subcommand=subcommand, + ) + + for option in options_map[command].union(global_options): + check_needed = command != name + fish_source += flag_template.format( + name=name, + command=command, + prev_global_check=prev_global_check if check_needed else "", + option=option.lstrip("--"), + ) + + return fish_source.format( + global_options=' '.join( + '"{option}"'.format(option=option) + for option in global_options + ) + ) def _IncludeMember(name, verbose): @@ -302,3 +428,51 @@ def _Commands(component, depth=3): for command in _Commands(member, depth - 1): yield (member_name,) + command + + +def _IsOption(arg): + return arg.startswith('-') + +def _GetMaps(name, commands, default_options): + """Returns sets of subcommands and options for each command. + + Args: + name: The first token in the commands, also the name of the command. + commands: A list of all possible commands that tab completion can complete + to. Each command is a list or tuple of the string tokens that make up + that command. + default_options: A dict of options that can be used with any command. Use + this if there are flags that can always be appended to a command. + Returns: + global_options: A set of all options of the first token of the command. + subcommands_map: A dict storing set of subcommands for each + command/subcommand. + options_map: A dict storing set of options for each subcommand. + """ + + global_options = copy.copy(default_options) + options_map = collections.defaultdict(lambda: copy.copy(default_options)) + subcommands_map = collections.defaultdict(set) + + for command in commands: + if len(command) == 1: + + if _IsOption(command[0]): + global_options.add(command[0]) + else: + subcommands_map[name].add(command[0]) + + elif command: + + subcommand = command[-2] + arg = _FormatForCommand(command[-1]) + + if _IsOption(arg): + args_map = options_map + else: + args_map = subcommands_map + + args_map[subcommand].add(arg) + args_map[subcommand.replace('_', '-')].add(arg) + + return global_options, options_map, subcommands_map diff --git a/fire/completion_test.py b/fire/completion_test.py index 47d80b4a..eaf80a70 100644 --- a/fire/completion_test.py +++ b/fire/completion_test.py @@ -36,7 +36,10 @@ def testCompletionBashScript(self): script = completion._BashScript(name='command', commands=commands) # pylint: disable=protected-access self.assertIn('command', script) self.assertIn('halt', script) - self.assertIn('"$start" == "command"', script) + + assert_template = "{command})" + for last_command in ['command', 'halt']: + self.assertIn(assert_template.format(command=last_command), script) def testCompletionFishScript(self): # A sanity check test to make sure the fish completion script satisfies