diff --git a/.gitignore b/.gitignore index 61dd6cb62f7..d062b7563db 100644 --- a/.gitignore +++ b/.gitignore @@ -62,3 +62,7 @@ gen # TestsResults TestsResults*.xml + +# Resharper settings +PowerShell.sln.DotSettings.user + diff --git a/src/System.Management.Automation/engine/CommandCompletion/CommandCompletion.cs b/src/System.Management.Automation/engine/CommandCompletion/CommandCompletion.cs index b92eadda029..74c3de498e1 100644 --- a/src/System.Management.Automation/engine/CommandCompletion/CommandCompletion.cs +++ b/src/System.Management.Automation/engine/CommandCompletion/CommandCompletion.cs @@ -641,7 +641,7 @@ private static List InvokeLegacyTabExpansion(PowerShell powers char quote; var lastword = LastWordFinder.FindLastWord(legacyInput, out replacementIndex, out quote); replacementLength = legacyInput.Length - replacementIndex; - var helper = new CompletionExecutionHelper(powershell); + var helper = new PowerShellExecutionHelper(powershell); powershell.AddCommand("TabExpansion").AddArgument(legacyInput).AddArgument(lastword); @@ -749,7 +749,7 @@ internal CommandAndName(PSObject command, PSSnapinQualifiedName commandName) /// /// /// - internal static List PSv2GenerateMatchSetOfCmdlets(CompletionExecutionHelper helper, string lastWord, string quote, bool completingAtStartOfLine) + internal static List PSv2GenerateMatchSetOfCmdlets(PowerShellExecutionHelper helper, string lastWord, string quote, bool completingAtStartOfLine) { var results = new List(); bool isSnapinSpecified; @@ -864,7 +864,7 @@ private static void PrependSnapInNameForSameCmdletNames(CommandAndName[] cmdlets #region "Handle File Names" - internal static List PSv2GenerateMatchSetOfFiles(CompletionExecutionHelper helper, string lastWord, bool completingAtStartOfLine, string quote) + internal static List PSv2GenerateMatchSetOfFiles(PowerShellExecutionHelper helper, string lastWord, bool completingAtStartOfLine, string quote) { var results = new List(); @@ -926,7 +926,7 @@ internal static List PSv2GenerateMatchSetOfFiles(CompletionExe bool? isContainer = SafeGetProperty(combinedMatch.Item, "PSIsContainer"); string childName = SafeGetProperty(combinedMatch.Item, "PSChildName"); - string toolTip = CompletionExecutionHelper.SafeToString(combinedMatch.ConvertedPath); + string toolTip = PowerShellExecutionHelper.SafeToString(combinedMatch.ConvertedPath); if (isContainer != null && childName != null && toolTip != null) { @@ -1034,7 +1034,7 @@ private static T SafeGetProperty(PSObject psObject, string propertyName) return default(T); } - private static bool PSv2ShouldFullyQualifyPathsPath(CompletionExecutionHelper helper, string lastWord) + private static bool PSv2ShouldFullyQualifyPathsPath(PowerShellExecutionHelper helper, string lastWord) { // These are special cases, as they represent cases where the user expects to // see the full path. @@ -1068,7 +1068,7 @@ internal PathItemAndConvertedPath(string path, PSObject item, string convertedPa } } - private static List PSv2FindMatches(CompletionExecutionHelper helper, string path, bool shouldFullyQualifyPaths) + private static List PSv2FindMatches(PowerShellExecutionHelper helper, string path, bool shouldFullyQualifyPaths) { Diagnostics.Assert(!String.IsNullOrEmpty(path), "path should have a value"); var result = new List(); @@ -1113,9 +1113,9 @@ private static List PSv2FindMatches(CompletionExecutio } result.Add(new PathItemAndConvertedPath( - CompletionExecutionHelper.SafeToString(objectPath), + PowerShellExecutionHelper.SafeToString(objectPath), item, - CompletionExecutionHelper.SafeToString(convertedPath))); + PowerShellExecutionHelper.SafeToString(convertedPath))); } } diff --git a/src/System.Management.Automation/engine/CommandCompletion/CompletionAnalysis.cs b/src/System.Management.Automation/engine/CommandCompletion/CompletionAnalysis.cs index 32057fce1c4..cc04f93f175 100644 --- a/src/System.Management.Automation/engine/CommandCompletion/CompletionAnalysis.cs +++ b/src/System.Management.Automation/engine/CommandCompletion/CompletionAnalysis.cs @@ -24,7 +24,7 @@ internal class CompletionContext internal Token TokenBeforeCursor { get; set; } internal IScriptPosition CursorPosition { get; set; } - internal CompletionExecutionHelper Helper { get; set; } + internal PowerShellExecutionHelper Helper { get; set; } internal Hashtable Options { get; set; } internal Dictionary CustomArgumentCompleters { get; set; } internal Dictionary NativeArgumentCompleters { get; set; } @@ -33,7 +33,7 @@ internal class CompletionContext internal int ReplacementLength { get; set; } internal ExecutionContext ExecutionContext { get; set; } internal PseudoBindingInfo PseudoBindingInfo { get; set; } - internal TypeDefinitionAst CurrentTypeDefinitionAst { get; set; } + internal TypeInferenceContext TypeInferenceContext { get; set; } internal bool GetOption(string option, bool @default) { @@ -96,15 +96,26 @@ private static bool IsCursorOutsideOfExtent(IScriptPosition cursor, IScriptExten return cursor.Offset < extent.StartOffset || cursor.Offset > extent.EndOffset; } - internal CompletionContext CreateCompletionContext(ExecutionContext executionContext) + + internal CompletionContext CreateCompletionContext(PowerShell powerShell) + { + var typeInferenceContext = new TypeInferenceContext(powerShell); + return InitializeCompletionContext(typeInferenceContext); + } + internal CompletionContext CreateCompletionContext(TypeInferenceContext typeInferenceContext) + { + return InitializeCompletionContext(typeInferenceContext); + } + + private CompletionContext InitializeCompletionContext(TypeInferenceContext typeInferenceContext) { Token tokenBeforeCursor = null; IScriptPosition positionForAstSearch = _cursorPosition; var adjustLineAndColumn = false; - var tokenAtCursor = _tokens.LastOrDefault(token => IsCursorWithinOrJustAfterExtent(_cursorPosition, token.Extent) && IsInterestingToken(token)); + var tokenAtCursor = InterstingTokenAtCursorOrDefault(_tokens, _cursorPosition); if (tokenAtCursor == null) { - tokenBeforeCursor = _tokens.LastOrDefault(token => IsCursorAfterExtent(_cursorPosition, token.Extent) && IsInterestingToken(token)); + tokenBeforeCursor = InterstingTokenBeforeCursorOrDefault(_tokens, _cursorPosition); if (tokenBeforeCursor != null) { positionForAstSearch = tokenBeforeCursor.Extent.EndScriptPosition; @@ -114,17 +125,22 @@ internal CompletionContext CreateCompletionContext(ExecutionContext executionCon else { var stringExpandableToken = tokenAtCursor as StringExpandableToken; - if (stringExpandableToken != null && stringExpandableToken.NestedTokens != null) + if (stringExpandableToken?.NestedTokens != null) { - tokenAtCursor = - stringExpandableToken.NestedTokens.LastOrDefault( - token => IsCursorWithinOrJustAfterExtent(_cursorPosition, token.Extent) && IsInterestingToken(token)) ?? stringExpandableToken; + tokenAtCursor = InterstingTokenAtCursorOrDefault(stringExpandableToken.NestedTokens, _cursorPosition) ?? stringExpandableToken; } } var asts = AstSearcher.FindAll(_ast, ast => IsCursorWithinOrJustAfterExtent(positionForAstSearch, ast.Extent), searchNestedScriptBlocks: true).ToList(); Diagnostics.Assert(tokenAtCursor == null || tokenBeforeCursor == null, "Only one of these tokens can be non-null"); + + if (typeInferenceContext.CurrentTypeDefinitionAst == null) + { + typeInferenceContext.CurrentTypeDefinitionAst = Ast.GetAncestorTypeDefinitionAst(asts.Last()); + } + ExecutionContext executionContext = typeInferenceContext.ExecutionContext; + return new CompletionContext { TokenAtCursor = tokenAtCursor, @@ -134,12 +150,24 @@ internal CompletionContext CreateCompletionContext(ExecutionContext executionCon Options = _options, ExecutionContext = executionContext, ReplacementIndex = adjustLineAndColumn ? _cursorPosition.Offset : 0, - CurrentTypeDefinitionAst = Ast.GetAncestorTypeDefinitionAst(asts.Last()), + TypeInferenceContext = typeInferenceContext, + Helper = typeInferenceContext.Helper, CustomArgumentCompleters = executionContext.CustomArgumentCompleters, NativeArgumentCompleters = executionContext.NativeArgumentCompleters, }; } + private static Token InterstingTokenAtCursorOrDefault(IEnumerable tokens, IScriptPosition cursorPosition) + { + return tokens.LastOrDefault(token => IsCursorWithinOrJustAfterExtent(cursorPosition, token.Extent) && IsInterestingToken(token)); + } + + private static Token InterstingTokenBeforeCursorOrDefault(IEnumerable tokens, IScriptPosition cursorPosition) + { + return tokens.LastOrDefault(token => IsCursorAfterExtent(cursorPosition, token.Extent) && IsInterestingToken(token)); + } + + private static Ast GetLastAstAtCursor(ScriptBlockAst scriptBlockAst, IScriptPosition cursorPosition) { var asts = AstSearcher.FindAll(scriptBlockAst, ast => IsCursorRightAfterExtent(cursorPosition, ast.Extent), searchNestedScriptBlocks: true); @@ -278,8 +306,7 @@ private static bool IsTokenTheSame(Token x, Token y) internal List GetResults(PowerShell powerShell, out int replacementIndex, out int replacementLength) { - var completionContext = CreateCompletionContext(powerShell.GetContextFromTLS()); - completionContext.Helper = new CompletionExecutionHelper(powerShell); + var completionContext = CreateCompletionContext(powerShell); PSLanguageMode? previousLanguageMode = null; try @@ -1167,8 +1194,7 @@ private List GetResultForString(CompletionContext completionCo cursorIndexInString = strValue.Length; var analysis = new CompletionAnalysis(_ast, _tokens, _cursorPosition, _options); - var subContext = analysis.CreateCompletionContext(completionContext.ExecutionContext); - subContext.Helper = completionContext.Helper; + var subContext = analysis.CreateCompletionContext(completionContext.TypeInferenceContext); int subReplaceIndex, subReplaceLength; var subResult = analysis.GetResultHelper(subContext, out subReplaceIndex, out subReplaceLength, true); diff --git a/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs b/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs index f7eb2ea69f5..56e88e50078 100644 --- a/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs +++ b/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs @@ -83,7 +83,7 @@ public static IEnumerable CompleteCommand(string commandName, return CommandCompletion.EmptyCompletionResult; } - var helper = new CompletionExecutionHelper(PowerShell.Create(RunspaceMode.CurrentRunspace)); + var helper = new PowerShellExecutionHelper(PowerShell.Create(RunspaceMode.CurrentRunspace)); return CompleteCommand(new CompletionContext { WordToComplete = commandName, Helper = helper }, moduleName, commandTypes); } @@ -112,8 +112,8 @@ private static List CompleteCommand(CompletionContext context, lastAst = context.RelatedAsts.Last(); } - var powershell = context.Helper.CurrentPowerShell; - AddCommandWithPreferenceSetting(powershell, "Get-Command", typeof(GetCommandCommand)) + var powershell = context.Helper + .AddCommandWithPreferenceSetting("Get-Command", typeof(GetCommandCommand)) .AddParameter("All") .AddParameter("Name", commandName); @@ -177,8 +177,8 @@ private static List CompleteCommand(CompletionContext context, moduleName = commandName.Substring(0, indexOfFirstBackslash); commandName = commandName.Substring(indexOfFirstBackslash + 1); - var powershell = context.Helper.CurrentPowerShell; - AddCommandWithPreferenceSetting(powershell, "Get-Command", typeof(GetCommandCommand)) + var powershell = context.Helper + .AddCommandWithPreferenceSetting("Get-Command", typeof(GetCommandCommand)) .AddParameter("All") .AddParameter("Name", commandName) .AddParameter("Module", moduleName); @@ -451,8 +451,7 @@ internal static List CompleteModuleName(CompletionContext cont moduleName += "*"; } - var powershell = context.Helper.CurrentPowerShell; - AddCommandWithPreferenceSetting(powershell, "Get-Module", typeof(GetModuleCommand)).AddParameter("Name", moduleName); + var powershell = context.Helper.AddCommandWithPreferenceSetting("Get-Module", typeof(GetModuleCommand)).AddParameter("Name", moduleName); if (!loadedModulesOnly) { powershell.AddParameter("ListAvailable", true); @@ -1901,7 +1900,7 @@ private static IEnumerable NativeCommandArgumentCompletion_InferType } } - foreach (PSTypeName typeName in argumentAst.GetInferredType(context)) + foreach (PSTypeName typeName in AstTypeInference.InferTypeOf(argumentAst, context.TypeInferenceContext, TypeInferenceRuntimePermissions.AllowSafeEval)) { yield return typeName; } @@ -2453,35 +2452,6 @@ private static void RemoveLastNullCompletionResult(List result } } - private static bool NativeCompletionCimCommands_ParseTypeName(PSTypeName typename, out string cimNamespace, out string className) - { - cimNamespace = null; - className = null; - if (typename == null) - { - return false; - } - if (typename.Type != null) - { - return false; - } - - var match = Regex.Match(typename.Name, "(?.*)#(?.*)[/\\\\](?.*)"); - if (!match.Success) - { - return false; - } - - if (!match.Groups["NetTypeName"].Value.Equals(typeof(CimInstance).FullName, StringComparison.OrdinalIgnoreCase)) - { - return false; - } - - cimNamespace = match.Groups["CimNamespace"].Value; - className = match.Groups["CimClassName"].Value; - return true; - } - private static void NativeCompletionCimCommands( string parameter, Dictionary boundArguments, @@ -2543,7 +2513,7 @@ private static void NativeCompletionCimCommands( { foreach (PSTypeName typeName in cimClassTypeNames) { - if (NativeCompletionCimCommands_ParseTypeName(typeName, out pseudoboundCimNamespace, out pseudoboundClassName)) + if (TypeInferenceContext.ParseCimCommandsTypeName(typeName, out pseudoboundCimNamespace, out pseudoboundClassName)) { if (parameter.Equals("ResultClassName", StringComparison.OrdinalIgnoreCase)) { @@ -2913,11 +2883,11 @@ private static void NativeCompletionEventLogCommands(string logName, string para } var pattern = WildcardPattern.Get(logName, WildcardOptions.IgnoreCase); - var powershell = context.Helper.CurrentPowerShell; - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-EventLog").AddParameter("LogName", "*"); + var powerShellExecutionHelper = context.Helper; + var powershell = powerShellExecutionHelper.AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-EventLog").AddParameter("LogName", "*"); Exception exceptionThrown; - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psObjects != null) { @@ -2956,7 +2926,6 @@ private static void NativeCompletionJobCommands(string wordToComplete, string pa wordToComplete = wordToComplete ?? string.Empty; var quote = HandleDoubleAndSingleQuote(ref wordToComplete); - var powershell = context.Helper.CurrentPowerShell; if (!wordToComplete.EndsWith("*", StringComparison.Ordinal)) { @@ -2964,17 +2933,13 @@ private static void NativeCompletionJobCommands(string wordToComplete, string pa } var pattern = WildcardPattern.Get(wordToComplete, WildcardOptions.IgnoreCase); - if (paramName.Equals("Name", StringComparison.OrdinalIgnoreCase)) - { - AddCommandWithPreferenceSetting(powershell, "Get-Job", typeof(GetJobCommand)).AddParameter("Name", wordToComplete); - } - else - { - AddCommandWithPreferenceSetting(powershell, "Get-Job", typeof(GetJobCommand)).AddParameter("IncludeChildJob", true); - } + var paramIsName = paramName.Equals("Name", StringComparison.OrdinalIgnoreCase); + var (parameterName, value) = paramIsName ? ("Name", wordToComplete) : ("IncludeChildJob", (object)true); + var powerShellExecutionHelper = context.Helper; + powerShellExecutionHelper.AddCommandWithPreferenceSetting("Get-Job", typeof(GetJobCommand)).AddParameter(parameterName, value); Exception exceptionThrown; - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psObjects == null) return; @@ -3012,7 +2977,7 @@ private static void NativeCompletionJobCommands(string wordToComplete, string pa result.Add(CompletionResult.Null); } - else if (paramName.Equals("Name", StringComparison.OrdinalIgnoreCase)) + else if (paramIsName) { RemoveLastNullCompletionResult(result); @@ -3047,7 +3012,6 @@ private static void NativeCompletionScheduledJobCommands(string wordToComplete, wordToComplete = wordToComplete ?? string.Empty; var quote = HandleDoubleAndSingleQuote(ref wordToComplete); - var powershell = context.Helper.CurrentPowerShell; if (!wordToComplete.EndsWith("*", StringComparison.Ordinal)) { @@ -3055,17 +3019,18 @@ private static void NativeCompletionScheduledJobCommands(string wordToComplete, } var pattern = WildcardPattern.Get(wordToComplete, WildcardOptions.IgnoreCase); + var powerShellExecutionHelper = context.Helper; if (paramName.Equals("Name", StringComparison.OrdinalIgnoreCase)) { - AddCommandWithPreferenceSetting(powershell, "PSScheduledJob\\Get-ScheduledJob").AddParameter("Name", wordToComplete); + powerShellExecutionHelper.AddCommandWithPreferenceSetting("PSScheduledJob\\Get-ScheduledJob").AddParameter("Name", wordToComplete); } else { - AddCommandWithPreferenceSetting(powershell, "PSScheduledJob\\Get-ScheduledJob"); + powerShellExecutionHelper.AddCommandWithPreferenceSetting("PSScheduledJob\\Get-ScheduledJob"); } Exception exceptionThrown; - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psObjects == null) return; @@ -3173,24 +3138,24 @@ private static void NativeCompletionProcessCommands(string wordToComplete, strin wordToComplete = wordToComplete ?? string.Empty; var quote = HandleDoubleAndSingleQuote(ref wordToComplete); - var powershell = context.Helper.CurrentPowerShell; if (!wordToComplete.EndsWith("*", StringComparison.Ordinal)) { wordToComplete += "*"; } + var powerShellExecutionHelper = context.Helper; if (paramName.Equals("Id", StringComparison.OrdinalIgnoreCase)) { - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-Process"); + powerShellExecutionHelper.AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-Process"); } else { - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-Process").AddParameter("Name", wordToComplete); + powerShellExecutionHelper.AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-Process").AddParameter("Name", wordToComplete); } Exception exceptionThrown; - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psObjects == null) return; @@ -3256,16 +3221,16 @@ private static void NativeCompletionProviderCommands(string providerName, string providerName = providerName ?? string.Empty; var quote = HandleDoubleAndSingleQuote(ref providerName); - var powershell = context.Helper.CurrentPowerShell; + if (!providerName.EndsWith("*", StringComparison.Ordinal)) { providerName += "*"; } - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-PSProvider").AddParameter("PSProvider", providerName); - Exception exceptionThrown; - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var powerShellExecutionHelper = context.Helper; + powerShellExecutionHelper.AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-PSProvider").AddParameter("PSProvider", providerName); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out _); if (psObjects == null) return; @@ -3301,19 +3266,20 @@ private static void NativeCompletionDriveCommands(string wordToComplete, string wordToComplete = wordToComplete ?? string.Empty; var quote = HandleDoubleAndSingleQuote(ref wordToComplete); - var powershell = context.Helper.CurrentPowerShell; if (!wordToComplete.EndsWith("*", StringComparison.Ordinal)) { wordToComplete += "*"; } - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-PSDrive").AddParameter("Name", wordToComplete); + var powerShellExecutionHelper = context.Helper; + var powershell = powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-PSDrive") + .AddParameter("Name", wordToComplete); if (psProvider != null) powershell.AddParameter("PSProvider", psProvider); - Exception exceptionThrown; - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out _); if (psObjects != null) { foreach (dynamic driveInfo in psObjects) @@ -3347,7 +3313,6 @@ private static void NativeCompletionServiceCommands(string wordToComplete, strin wordToComplete = wordToComplete ?? string.Empty; var quote = HandleDoubleAndSingleQuote(ref wordToComplete); - var powershell = context.Helper.CurrentPowerShell; if (!wordToComplete.EndsWith("*", StringComparison.Ordinal)) { @@ -3355,15 +3320,17 @@ private static void NativeCompletionServiceCommands(string wordToComplete, strin } Exception exceptionThrown; + var powerShellExecutionHelper = context.Helper; if (paramName.Equals("DisplayName", StringComparison.OrdinalIgnoreCase)) { RemoveLastNullCompletionResult(result); - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-Service") - .AddParameter("DisplayName", wordToComplete); - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Utility\\Sort-Object") + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-Service") + .AddParameter("DisplayName", wordToComplete) + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Utility\\Sort-Object") .AddParameter("Property", "DisplayName"); - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psObjects != null) { foreach (dynamic serviceInfo in psObjects) @@ -3393,8 +3360,8 @@ private static void NativeCompletionServiceCommands(string wordToComplete, strin { RemoveLastNullCompletionResult(result); - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-Service").AddParameter("Name", wordToComplete); - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + powerShellExecutionHelper.AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-Service").AddParameter("Name", wordToComplete); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psObjects != null) { foreach (dynamic serviceInfo in psObjects) @@ -3433,16 +3400,14 @@ private static void NativeCompletionVariableCommands(string variableName, string variableName = variableName ?? string.Empty; var quote = HandleDoubleAndSingleQuote(ref variableName); - var powershell = context.Helper.CurrentPowerShell; - if (!variableName.EndsWith("*", StringComparison.Ordinal)) { variableName += "*"; } - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Utility\\Get-Variable").AddParameter("Name", variableName); - Exception exceptionThrown; - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var powerShellExecutionHelper = context.Helper; + var powershell = powerShellExecutionHelper.AddCommandWithPreferenceSetting("Microsoft.PowerShell.Utility\\Get-Variable").AddParameter("Name", variableName); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out _); if (psObjects == null) return; @@ -3489,11 +3454,11 @@ private static void NativeCompletionAliasCommands(string commandName, string par RemoveLastNullCompletionResult(result); + var powerShellExecutionHelper = context.Helper; if (paramName.Equals("Name", StringComparison.OrdinalIgnoreCase)) { commandName = commandName ?? string.Empty; var quote = HandleDoubleAndSingleQuote(ref commandName); - var powershell = context.Helper.CurrentPowerShell; if (!commandName.EndsWith("*", StringComparison.Ordinal)) { @@ -3501,8 +3466,8 @@ private static void NativeCompletionAliasCommands(string commandName, string par } Exception exceptionThrown; - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Utility\\Get-Alias").AddParameter("Name", commandName); - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var powershell = powerShellExecutionHelper.AddCommandWithPreferenceSetting("Microsoft.PowerShell.Utility\\Get-Alias").AddParameter("Name", commandName); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psObjects != null) { foreach (dynamic aliasInfo in psObjects) @@ -3531,12 +3496,12 @@ private static void NativeCompletionAliasCommands(string commandName, string par // Complete for the parameter Definition // Available commands const CommandTypes commandTypes = CommandTypes.Cmdlet | CommandTypes.Function | CommandTypes.ExternalScript | CommandTypes.Workflow | CommandTypes.Configuration; - var commandResults = CompleteCommand(new CompletionContext { WordToComplete = commandName, Helper = context.Helper }, null, commandTypes); + var commandResults = CompleteCommand(new CompletionContext { WordToComplete = commandName, Helper = powerShellExecutionHelper }, null, commandTypes); if (commandResults != null && commandResults.Count > 0) result.AddRange(commandResults); // The parameter Definition takes a file - var fileResults = new List(CompleteFilename(new CompletionContext { WordToComplete = commandName, Helper = context.Helper })); + var fileResults = new List(CompleteFilename(new CompletionContext { WordToComplete = commandName, Helper = powerShellExecutionHelper })); if (fileResults.Count > 0) result.AddRange(fileResults); } @@ -3555,16 +3520,16 @@ private static void NativeCompletionTraceSourceCommands(string traceSourceName, traceSourceName = traceSourceName ?? string.Empty; var quote = HandleDoubleAndSingleQuote(ref traceSourceName); - var powershell = context.Helper.CurrentPowerShell; if (!traceSourceName.EndsWith("*", StringComparison.Ordinal)) { traceSourceName += "*"; } - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Utility\\Get-TraceSource").AddParameter("Name", traceSourceName); + var powerShellExecutionHelper = context.Helper; + var powershell = powerShellExecutionHelper.AddCommandWithPreferenceSetting("Microsoft.PowerShell.Utility\\Get-TraceSource").AddParameter("Name", traceSourceName); Exception exceptionThrown; - var psObjects = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var psObjects = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psObjects == null) return; @@ -3792,11 +3757,11 @@ private static void NativeCompletionMemberName(string wordToComplete, List CompleteFilename(string fileName) return CommandCompletion.EmptyCompletionResult; } - var helper = new CompletionExecutionHelper(PowerShell.Create(RunspaceMode.CurrentRunspace)); + var helper = new PowerShellExecutionHelper(PowerShell.Create(RunspaceMode.CurrentRunspace)); return CompleteFilename(new CompletionContext { WordToComplete = fileName, Helper = helper }); } @@ -4115,16 +4080,13 @@ internal static IEnumerable CompleteFilename(CompletionContext } else { - var powershell = context.Helper.CurrentPowerShell; - var executionContext = powershell.GetContextFromTLS(); - // We want to prefer relative paths in a completion result unless the user has already // specified a drive or portion of the path. - string unused; + var executionContext = context.ExecutionContext; var defaultRelative = string.IsNullOrWhiteSpace(wordToComplete) || (wordToComplete.IndexOfAny(Utils.Separators.Directory) != 0 && !Regex.Match(wordToComplete, @"^~[\\/]+.*").Success && - !executionContext.LocationGlobber.IsAbsolutePath(wordToComplete, out unused)); + !executionContext.LocationGlobber.IsAbsolutePath(wordToComplete, out _)); var relativePaths = context.GetOption("RelativePaths", @default: defaultRelative); var useLiteralPath = context.GetOption("LiteralPaths", @default: false); @@ -4133,18 +4095,20 @@ internal static IEnumerable CompleteFilename(CompletionContext wordToComplete = WildcardPattern.Escape(wordToComplete, Utils.Separators.StarOrQuestion); } - if (!defaultRelative && wordToComplete.Length >= 2 && wordToComplete[1] == ':' && char.IsLetter(wordToComplete[0]) && context.ExecutionContext != null) + if (!defaultRelative && wordToComplete.Length >= 2 && wordToComplete[1] == ':' && char.IsLetter(wordToComplete[0]) && executionContext != null) { // We don't actually need the drive, but the drive must be "mounted" in PowerShell before completion // can succeed. This call will mount the drive if it wasn't already. - context.ExecutionContext.SessionState.Drive.GetAtScope(wordToComplete.Substring(0, 1), "global"); + executionContext.SessionState.Drive.GetAtScope(wordToComplete.Substring(0, 1), "global"); } - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Resolve-Path") + var powerShellExecutionHelper = context.Helper; + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Resolve-Path") .AddParameter("Path", wordToComplete + "*"); Exception exceptionThrown; - var psobjs = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var psobjs = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psobjs != null) { @@ -4263,11 +4227,12 @@ internal static IEnumerable CompleteFilename(CompletionContext if (!hiddenFilesAreHandled) { - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-ChildItem") + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-ChildItem") .AddParameter("Path", wordToComplete + "*") .AddParameter("Hidden", true); - var hiddenItems = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var hiddenItems = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (hiddenItems != null && hiddenItems.Count > 0) { foreach (var hiddenItem in hiddenItems) @@ -4402,9 +4367,10 @@ internal static IEnumerable CompleteFilename(CompletionContext } else { - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-Item") + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-Item") .AddParameter("LiteralPath", path); - var items = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var items = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (items != null && items.Count == 1) { dynamic item = items[0]; @@ -4413,9 +4379,10 @@ internal static IEnumerable CompleteFilename(CompletionContext if (containerOnly && !isContainer) continue; - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Convert-Path") + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Convert-Path") .AddParameter("LiteralPath", item.PSPath); - var tooltips = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var tooltips = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); string tooltip = null, listItemText = item.PSChildName; if (tooltips != null && tooltips.Count == 1) { @@ -4522,7 +4489,7 @@ public static IEnumerable CompleteVariable(string variableName return CommandCompletion.EmptyCompletionResult; } - var helper = new CompletionExecutionHelper(PowerShell.Create(RunspaceMode.CurrentRunspace)); + var helper = new PowerShellExecutionHelper(PowerShell.Create(RunspaceMode.CurrentRunspace)); return CompleteVariable(new CompletionContext { WordToComplete = variableName, Helper = helper }); } @@ -4625,7 +4592,7 @@ internal static List CompleteVariable(CompletionContext contex var commandAst = ast as CommandAst; if (commandAst != null) { - PSTypeName discoveredType = ast.GetInferredType(context).FirstOrDefault(); + PSTypeName discoveredType = AstTypeInference.InferTypeOf(ast, context.TypeInferenceContext, TypeInferenceRuntimePermissions.AllowSafeEval).FirstOrDefault(); if (discoveredType != null) { tooltip = StringUtil.Format("[{0}]${1}", discoveredType.Name, userPath); @@ -4660,12 +4627,13 @@ internal static List CompleteVariable(CompletionContext contex } } - var powershell = context.Helper.CurrentPowerShell; - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-Item").AddParameter("Path", pattern); - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Utility\\Sort-Object").AddParameter("Property", "Name"); + var powerShellExecutionHelper = context.Helper; + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-Item").AddParameter("Path", pattern) + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Utility\\Sort-Object").AddParameter("Property", "Name"); Exception exceptionThrown; - var psobjs = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + var psobjs = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psobjs != null) { foreach (dynamic psobj in psobjs) @@ -4696,11 +4664,11 @@ internal static List CompleteVariable(CompletionContext contex if (colon == -1 && "env".StartsWith(wordToComplete, StringComparison.OrdinalIgnoreCase)) { - powershell = context.Helper.CurrentPowerShell; - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-Item").AddParameter("Path", "env:*"); - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Utility\\Sort-Object").AddParameter("Property", "Key"); + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-Item").AddParameter("Path", "env:*") + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Utility\\Sort-Object").AddParameter("Property", "Key"); - psobjs = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + psobjs = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psobjs != null) { foreach (dynamic psobj in psobjs) @@ -4736,9 +4704,10 @@ internal static List CompleteVariable(CompletionContext contex { // If no drive was specified, then look for matching drives/scopes pattern = wordToComplete + "*"; - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Management\\Get-PSDrive").AddParameter("Name", pattern); - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Utility\\Sort-Object").AddParameter("Property", "Name"); - psobjs = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Management\\Get-PSDrive").AddParameter("Name", pattern) + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Utility\\Sort-Object").AddParameter("Property", "Name"); + psobjs = powerShellExecutionHelper.ExecuteCurrentPowerShell(out exceptionThrown); if (psobjs != null) { foreach (var psobj in psobjs) @@ -4873,15 +4842,13 @@ internal static List CompleteComment(CompletionContext context if (!matchResult.Success) { return results; } string wordToComplete = matchResult.Groups[1].Value; - PowerShell powershell = context.Helper.CurrentPowerShell; Collection psobjs; - Exception exceptionThrown; int entryId; if (Regex.IsMatch(wordToComplete, @"^[0-9]+$") && LanguagePrimitives.TryConvertTo(wordToComplete, out entryId)) { - AddCommandWithPreferenceSetting(powershell, "Get-History", typeof(GetHistoryCommand)).AddParameter("Id", entryId); - psobjs = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + context.Helper.AddCommandWithPreferenceSetting("Get-History", typeof(GetHistoryCommand)).AddParameter("Id", entryId); + psobjs = context.Helper.ExecuteCurrentPowerShell(out _); if (psobjs != null && psobjs.Count == 1) { @@ -4905,9 +4872,9 @@ internal static List CompleteComment(CompletionContext context } wordToComplete = "*" + wordToComplete + "*"; - AddCommandWithPreferenceSetting(powershell, "Get-History", typeof(GetHistoryCommand)); + context.Helper.AddCommandWithPreferenceSetting("Get-History", typeof(GetHistoryCommand)); - psobjs = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown); + psobjs = context.Helper.ExecuteCurrentPowerShell(out _); var pattern = WildcardPattern.Get(wordToComplete, WildcardOptions.IgnoreCase); if (psobjs != null) @@ -5074,7 +5041,7 @@ internal static List CompleteMember(CompletionContext context, } else { - inferredTypes = targetExpr.GetInferredType(context).ToArray(); + inferredTypes = AstTypeInference.InferTypeOf(targetExpr, context.TypeInferenceContext, TypeInferenceRuntimePermissions.AllowSafeEval).ToArray(); } if (inferredTypes != null && inferredTypes.Length > 0) @@ -5196,7 +5163,7 @@ private static void CompleteMemberByInferredType(CompletionContext context, IEnu continue; } typeNameUsed.Add(psTypeName.Name); - var members = GetMembersByInferredType(psTypeName, context, isStatic, filter); + var members = context.TypeInferenceContext.GetMembersByInferredType(psTypeName, isStatic, filter); foreach (var member in members) { AddInferredMember(member, memberNamePattern, results); @@ -5214,11 +5181,13 @@ private static void CompleteMemberByInferredType(CompletionContext context, IEnu if (results.Count > 0) { // Sort the results - AddCommandWithPreferenceSetting(context.Helper.CurrentPowerShell, "Microsoft.PowerShell.Utility\\Sort-Object") + var powerShellExecutionHelper = context.Helper; + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Utility\\Sort-Object") .AddParameter("Property", new[] { "ResultType", "ListItemText" }) .AddParameter("Unique"); Exception unused; - var sortedResults = context.Helper.ExecuteCurrentPowerShell(out unused, results); + var sortedResults = powerShellExecutionHelper.ExecuteCurrentPowerShell(out unused, results); results.Clear(); results.AddRange(sortedResults.Select(psobj => PSObject.Base(psobj) as CompletionResult)); } @@ -5372,153 +5341,6 @@ private static bool IsConstructor(object member) return false; } - internal static IEnumerable GetMembersByInferredType(PSTypeName typename, CompletionContext context, bool @static, Func filter) - { - List results = new List(); - - Func filterToCall = filter; - if (typename.Type != null) - { - if (context.CurrentTypeDefinitionAst == null || context.CurrentTypeDefinitionAst.Type != typename.Type) - { - if (filterToCall == null) - filterToCall = o => !IsMemberHidden(o); - else - filterToCall = o => !IsMemberHidden(o) && filter(o); - } - IEnumerable elementTypes; - if (typename.Type.IsArray) - { - elementTypes = new[] { typename.Type.GetElementType() }; - } - else - { - elementTypes = typename.Type.GetInterfaces().Where( - t => t.GetTypeInfo().IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>)); - } - foreach (var type in elementTypes.Prepend(typename.Type)) - { - // Look in the type table first. - if (!@static) - { - var consolidatedString = DotNetAdapter.GetInternedTypeNameHierarchy(type); - results.AddRange(context.ExecutionContext.TypeTable.GetMembers(consolidatedString)); - } - - var members = @static - ? PSObject.dotNetStaticAdapter.BaseGetMembers(type) - : PSObject.dotNetInstanceAdapter.GetPropertiesAndMethods(type, false); - results.AddRange(filterToCall != null ? members.Where(filterToCall) : members); - } - } - else if (typename.TypeDefinitionAst != null) - { - if (context.CurrentTypeDefinitionAst != typename.TypeDefinitionAst) - { - if (filterToCall == null) - filterToCall = o => !IsMemberHidden(o); - else - filterToCall = o => !IsMemberHidden(o) && filter(o); - } - - bool foundConstructor = false; - foreach (var member in typename.TypeDefinitionAst.Members) - { - bool add = false; - var propertyMember = member as PropertyMemberAst; - if (propertyMember != null) - { - if (propertyMember.IsStatic == @static) - { - add = true; - } - } - else - { - var functionMember = (FunctionMemberAst)member; - if (functionMember.IsStatic == @static) - { - add = true; - } - foundConstructor |= functionMember.IsConstructor; - } - - if (filterToCall != null && add) - { - add = filterToCall(member); - } - - if (add) - { - results.Add(member); - } - } - - //iterate through bases/interfaces - foreach (var baseType in typename.TypeDefinitionAst.BaseTypes) - { - TypeName baseTypeName = baseType.TypeName as TypeName; - if (baseTypeName != null) - { - TypeDefinitionAst baseTypeDefinitionAst = baseTypeName._typeDefinitionAst; - results.AddRange(GetMembersByInferredType(new PSTypeName(baseTypeDefinitionAst), context, @static, filterToCall)); - } - } - - // Add stuff from our base class System.Object. - if (@static) - { - // Don't add base class constructors - if (filter == null) - { - filterToCall = o => !IsConstructor(o); - } - else - { - filterToCall = o => !IsConstructor(o) && filter(o); - } - - if (!foundConstructor) - { - results.Add( - new CompilerGeneratedMemberFunctionAst(PositionUtilities.EmptyExtent, typename.TypeDefinitionAst, - SpecialMemberFunctionType.DefaultConstructor)); - } - } - else - { - // Reset the filter because the recursive call will add IsHidden back if necessary. - filterToCall = filter; - } - results.AddRange(GetMembersByInferredType(new PSTypeName(typeof(object)), context, @static, filterToCall)); - } - else - { - // Look in the type table first. - if (!@static) - { - var consolidatedString = new ConsolidatedString(new string[] { typename.Name }); - results.AddRange(context.ExecutionContext.TypeTable.GetMembers(consolidatedString)); - } - - string cimNamespace; - string className; - if (NativeCompletionCimCommands_ParseTypeName(typename, out cimNamespace, out className)) - { - AddCommandWithPreferenceSetting(context.Helper.CurrentPowerShell, "CimCmdlets\\Get-CimClass") - .AddParameter("Namespace", cimNamespace) - .AddParameter("Class", className); - Exception unused; - var classes = context.Helper.ExecuteCurrentPowerShell(out unused); - foreach (var @class in classes.Select(PSObject.Base).OfType()) - { - results.AddRange(filterToCall != null ? @class.CimClassProperties.Where(filterToCall) : @class.CimClassProperties); - } - } - } - - return results; - } #endregion Members @@ -6041,7 +5863,7 @@ public static IEnumerable CompleteType(string typeName) ? PowerShell.Create() : PowerShell.Create(RunspaceMode.CurrentRunspace); - var helper = new CompletionExecutionHelper(powershell); + var helper = new PowerShellExecutionHelper(powershell); return CompleteType(new CompletionContext { WordToComplete = typeName, Helper = helper }); } @@ -6329,7 +6151,7 @@ internal static List CompleteHashtableKey(CompletionContext co { var result = new List(); CompleteMemberByInferredType( - completionContext, typeAst.GetInferredType(completionContext), + completionContext, AstTypeInference.InferTypeOf(typeAst, completionContext.TypeInferenceContext, TypeInferenceRuntimePermissions.AllowSafeEval), result, completionContext.WordToComplete + "*", IsWriteablePropertyMember, isStatic: false); return result; } @@ -6432,7 +6254,7 @@ internal static List CompleteHashtableKey(CompletionContext co switch (binding.CommandName) { case "New-Object": - var inferredType = commandAst.GetInferredType(completionContext); + var inferredType = AstTypeInference.InferTypeOf(commandAst, completionContext.TypeInferenceContext, TypeInferenceRuntimePermissions.AllowSafeEval); var result = new List(); CompleteMemberByInferredType( completionContext, inferredType, @@ -6471,30 +6293,6 @@ private static List GetSpecialHashTableKeyMembers(params strin #region Helpers - internal static PowerShell AddCommandWithPreferenceSetting(PowerShell powershell, string command, Type type = null) - { - Diagnostics.Assert(powershell != null, "the passed-in powershell cannot be null"); - Diagnostics.Assert(!String.IsNullOrWhiteSpace(command), "the passed-in command name should not be null or whitespaces"); - - if (type != null) - { - var cmdletInfo = new CmdletInfo(command, type); - powershell.AddCommand(cmdletInfo); - } - else - { - powershell.AddCommand(command); - } - powershell - .AddParameter("ErrorAction", ActionPreference.Ignore) - .AddParameter("WarningAction", ActionPreference.Ignore) - .AddParameter("InformationAction", ActionPreference.Ignore) - .AddParameter("Verbose", false) - .AddParameter("Debug", false); - - return powershell; - } - internal static bool IsPathSafelyExpandable(ExpandableStringExpressionAst expandableStringAst, string extraText, ExecutionContext executionContext, out string expandedString) { expandedString = null; @@ -6620,14 +6418,14 @@ internal static void CompleteMemberHelper( value = new[] { value }; } - var powershell = context.Helper.CurrentPowerShell; - // Instead of Get-Member, we access the members directly and send as input to the pipe. - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Core\\Where-Object") + var powerShellExecutionHelper = context.Helper; + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Core\\Where-Object") .AddParameter("Property", "Name") .AddParameter("Like") - .AddParameter("Value", memberName); - AddCommandWithPreferenceSetting(powershell, "Microsoft.PowerShell.Utility\\Sort-Object") + .AddParameter("Value", memberName) + .AddCommandWithPreferenceSetting("Microsoft.PowerShell.Utility\\Sort-Object") .AddParameter("Property", new object[] { "MemberType", "Name" }); IEnumerable members; @@ -6644,8 +6442,7 @@ internal static void CompleteMemberHelper( { members = PSObject.AsPSObject(value).Members; } - Exception exceptionThrown; - var sortedMembers = context.Helper.ExecuteCurrentPowerShell(out exceptionThrown, members); + var sortedMembers = powerShellExecutionHelper.ExecuteCurrentPowerShell(out _, members); foreach (var member in sortedMembers) { diff --git a/src/System.Management.Automation/engine/CommandCompletion/PseudoParameterBinder.cs b/src/System.Management.Automation/engine/CommandCompletion/PseudoParameterBinder.cs index 9715c3db2f1..cb13722f202 100644 --- a/src/System.Management.Automation/engine/CommandCompletion/PseudoParameterBinder.cs +++ b/src/System.Management.Automation/engine/CommandCompletion/PseudoParameterBinder.cs @@ -1253,10 +1253,7 @@ private bool PrepareCommandElements(ExecutionContext context) var expressionArgument = _commandElements[commandIndex] as ExpressionAst; if (expressionArgument != null) { - if (argumentsToGetDynamicParameters != null) - { - argumentsToGetDynamicParameters.Add(expressionArgument.Extent.Text); - } + argumentsToGetDynamicParameters?.Add(expressionArgument.Extent.Text); _arguments.Add(new AstPair(null, expressionArgument)); } @@ -1338,79 +1335,76 @@ private bool PrepareCommandElements(ExecutionContext context) if (_commandAst.IsInWorkflow()) { var converterType = Type.GetType(Utils.WorkflowType); - if (converterType != null) + var activityParameters = (Dictionary) converterType?.GetMethod("GetActivityParameters").Invoke(null, new object[] { _commandAst }); + if (activityParameters != null) { - var activityParameters = (Dictionary)converterType.GetMethod("GetActivityParameters").Invoke(null, new object[] { _commandAst }); - if (activityParameters != null) - { - bool needToRemoveReplacedProperty = activityParameters.ContainsKey("PSComputerName") && - !activityParameters.ContainsKey("ComputerName"); + bool needToRemoveReplacedProperty = activityParameters.ContainsKey("PSComputerName") && + !activityParameters.ContainsKey("ComputerName"); - var parametersToAdd = new List(); - var attrCollection = new Collection { new ParameterAttribute() }; - foreach (var pair in activityParameters) + var parametersToAdd = new List(); + var attrCollection = new Collection { new ParameterAttribute() }; + foreach (var pair in activityParameters) + { + if (psuedoWorkflowCommand || !_bindableParameters.BindableParameters.ContainsKey(pair.Key)) { - if (psuedoWorkflowCommand || !_bindableParameters.BindableParameters.ContainsKey(pair.Key)) - { - Type parameterType = GetActualActivityParameterType(pair.Value); - var runtimeDefinedParameter = new RuntimeDefinedParameter(pair.Key, parameterType, attrCollection); - var compiledCommandParameter = new CompiledCommandParameter(runtimeDefinedParameter, false) { IsInAllSets = true }; - var mergedCompiledCommandParameter = new MergedCompiledCommandParameter(compiledCommandParameter, ParameterBinderAssociation.DeclaredFormalParameters); - parametersToAdd.Add(mergedCompiledCommandParameter); - } + Type parameterType = GetActualActivityParameterType(pair.Value); + var runtimeDefinedParameter = new RuntimeDefinedParameter(pair.Key, parameterType, attrCollection); + var compiledCommandParameter = new CompiledCommandParameter(runtimeDefinedParameter, false) { IsInAllSets = true }; + var mergedCompiledCommandParameter = new MergedCompiledCommandParameter(compiledCommandParameter, ParameterBinderAssociation.DeclaredFormalParameters); + parametersToAdd.Add(mergedCompiledCommandParameter); } - if (parametersToAdd.Any()) + } + if (parametersToAdd.Any()) + { + var mergedBindableParameters = new MergedCommandParameterMetadata(); + if (!psuedoWorkflowCommand) { - var mergedBindableParameters = new MergedCommandParameterMetadata(); - if (!psuedoWorkflowCommand) - { - mergedBindableParameters.ReplaceMetadata(_bindableParameters); - } - foreach (var p in parametersToAdd) - { - mergedBindableParameters.BindableParameters.Add(p.Parameter.Name, p); - } - _bindableParameters = mergedBindableParameters; + mergedBindableParameters.ReplaceMetadata(_bindableParameters); } + foreach (var p in parametersToAdd) + { + mergedBindableParameters.BindableParameters.Add(p.Parameter.Name, p); + } + _bindableParameters = mergedBindableParameters; + } - // Remove common parameters that are supported by all commands, but not - // by workflows - bool fixedReadOnly = false; - foreach (var ignored in _ignoredWorkflowParameters) + // Remove common parameters that are supported by all commands, but not + // by workflows + bool fixedReadOnly = false; + foreach (var ignored in _ignoredWorkflowParameters) + { + if (_bindableParameters.BindableParameters.ContainsKey(ignored)) { - if (_bindableParameters.BindableParameters.ContainsKey(ignored)) + // However, some ignored parameters are explicitly implemented by + // activities, so keep them. + if (!activityParameters.ContainsKey(ignored)) { - // However, some ignored parameters are explicitly implemented by - // activities, so keep them. - if (!activityParameters.ContainsKey(ignored)) + if (!fixedReadOnly) { - if (!fixedReadOnly) - { - _bindableParameters.ResetReadOnly(); - fixedReadOnly = true; - } - - _bindableParameters.BindableParameters.Remove(ignored); + _bindableParameters.ResetReadOnly(); + fixedReadOnly = true; } + + _bindableParameters.BindableParameters.Remove(ignored); } } + } - if (_bindableParameters.BindableParameters.ContainsKey("ComputerName") && needToRemoveReplacedProperty) + if (_bindableParameters.BindableParameters.ContainsKey("ComputerName") && needToRemoveReplacedProperty) + { + if (!fixedReadOnly) { - if (!fixedReadOnly) - { - _bindableParameters.ResetReadOnly(); - fixedReadOnly = true; - } + _bindableParameters.ResetReadOnly(); + fixedReadOnly = true; + } - _bindableParameters.BindableParameters.Remove("ComputerName"); - string aliasOfComputerName = (from aliasPair in _bindableParameters.AliasedParameters - where String.Equals("ComputerName", aliasPair.Value.Parameter.Name) - select aliasPair.Key).FirstOrDefault(); - if (aliasOfComputerName != null) - { - _bindableParameters.AliasedParameters.Remove(aliasOfComputerName); - } + _bindableParameters.BindableParameters.Remove("ComputerName"); + string aliasOfComputerName = (from aliasPair in _bindableParameters.AliasedParameters + where String.Equals("ComputerName", aliasPair.Value.Parameter.Name) + select aliasPair.Key).FirstOrDefault(); + if (aliasOfComputerName != null) + { + _bindableParameters.AliasedParameters.Remove(aliasOfComputerName); } } } @@ -1451,32 +1445,36 @@ private CommandProcessorBase PrepareFromAst(ExecutionContext context, out string } ast.Visit(exportVisitor); - resolvedCommandName = _commandAst.GetCommandName(); CommandProcessorBase commandProcessor = null; - string alias; - int resolvedAliasCount = 0; - while (exportVisitor.DiscoveredAliases.TryGetValue(resolvedCommandName, out alias)) + resolvedCommandName = _commandAst.GetCommandName(); + if (resolvedCommandName != null) { - resolvedAliasCount += 1; - if (resolvedAliasCount > 5) - break; // give up, assume it's recursive - resolvedCommandName = alias; - } + string alias; + int resolvedAliasCount = 0; - FunctionDefinitionAst functionDefinitionAst; - if (exportVisitor.DiscoveredFunctions.TryGetValue(resolvedCommandName, out functionDefinitionAst)) - { - // We could use the IAstToScriptBlockConverter to get the actual script block, but that can be fairly expensive for workflows. - // IAstToScriptBlockConverter is public, so we might consider converting non-workflows, but the interface isn't really designed - // for Intellisense, so we can't really expect good performance, so instead we'll just fall back on the actual - // parameters we see in the ast. - var scriptBlock = functionDefinitionAst.IsWorkflow - ? CreateFakeScriptBlockForWorkflow(functionDefinitionAst) - : new ScriptBlock(functionDefinitionAst, functionDefinitionAst.IsFilter); - commandProcessor = CommandDiscovery.CreateCommandProcessorForScript(scriptBlock, context, true, context.EngineSessionState); - } + while (exportVisitor.DiscoveredAliases.TryGetValue(resolvedCommandName, out alias)) + { + resolvedAliasCount += 1; + if (resolvedAliasCount > 5) + break; // give up, assume it's recursive + resolvedCommandName = alias; + } + FunctionDefinitionAst functionDefinitionAst; + if (exportVisitor.DiscoveredFunctions.TryGetValue(resolvedCommandName, out functionDefinitionAst)) + { + // We could use the IAstToScriptBlockConverter to get the actual script block, but that can be fairly expensive for workflows. + // IAstToScriptBlockConverter is public, so we might consider converting non-workflows, but the interface isn't really designed + // for Intellisense, so we can't really expect good performance, so instead we'll just fall back on the actual + // parameters we see in the ast. + var scriptBlock = functionDefinitionAst.IsWorkflow + ? CreateFakeScriptBlockForWorkflow(functionDefinitionAst) + : new ScriptBlock(functionDefinitionAst, functionDefinitionAst.IsFilter); + commandProcessor = CommandDiscovery.CreateCommandProcessorForScript(scriptBlock, context, true, context.EngineSessionState); + } + + } return commandProcessor; } @@ -1525,7 +1523,7 @@ private static ScriptBlock CreateFakeScriptBlockForWorkflow(FunctionDefinitionAs var paramBlockAst = functionDefinitionAst.Body.ParamBlock; if (paramBlockAst != null) { - var outputTypeAttrs = paramBlockAst.Attributes.Where(attribute => typeof(OutputTypeAttribute).Equals(attribute.TypeName.GetReflectionAttributeType())); + var outputTypeAttrs = paramBlockAst.Attributes.Where(attribute => typeof(OutputTypeAttribute) == attribute.TypeName.GetReflectionAttributeType()); foreach (AttributeAst attributeAst in outputTypeAttrs) { diff --git a/src/System.Management.Automation/engine/parser/TypeInferenceVisitor.cs b/src/System.Management.Automation/engine/parser/TypeInferenceVisitor.cs new file mode 100644 index 00000000000..0297f000c67 --- /dev/null +++ b/src/System.Management.Automation/engine/parser/TypeInferenceVisitor.cs @@ -0,0 +1,1577 @@ +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Linq; +using System.Management.Automation.Language; +using System.Management.Automation.Runspaces; +using System.Reflection; +using System.Text.RegularExpressions; +using Microsoft.PowerShell.Commands; + +using CimClass = Microsoft.Management.Infrastructure.CimClass; +using CimInstance = Microsoft.Management.Infrastructure.CimInstance; + +namespace System.Management.Automation +{ + /// + /// Enum describing permissions to use runtime evaluation during type inference + /// + public enum TypeInferenceRuntimePermissions { + /// + /// No runtime use is allowed + /// + None = 0, + /// + /// Use of SafeExprEvaluator visitor is allowed + /// + AllowSafeEval = 1, + } + + /// + /// static class containing methods to work with type inference of abstract syntax trees + /// + internal static class AstTypeInference + { + /// + /// Infers the type that the result of executing a statement would have without using runtime safe eval + /// + /// the ast to infer the type from + /// + public static IList InferTypeOf(Ast ast) + { + return InferTypeOf(ast, TypeInferenceRuntimePermissions.None); + } + + /// + /// Infers the type that the result of executing a statement would have + /// + /// the ast to infer the type from + /// The runtime usage permissions allowed during type inference + /// + public static IList InferTypeOf(Ast ast, TypeInferenceRuntimePermissions evalPermissions) + { + return InferTypeOf(ast, PowerShell.Create(RunspaceMode.CurrentRunspace), evalPermissions); + } + + /// + /// Infers the type that the result of executing a statement would have without using runtime safe eval + /// + /// the ast to infer the type from + /// the instance of powershell to user for expression evalutaion needed for type inference + /// + public static IList InferTypeOf(Ast ast, PowerShell powerShell) + { + return InferTypeOf(ast, powerShell, TypeInferenceRuntimePermissions.None); + } + + /// + /// Infers the type that the result of executing a statement would have + /// + /// the ast to infer the type from + /// the instance of powershell to user for expression evalutaion needed for type inference + /// The runtime usage permissions allowed during type inference + /// + public static IList InferTypeOf(Ast ast, PowerShell powerShell, TypeInferenceRuntimePermissions evalPersmissions) + { + var context = new TypeInferenceContext(powerShell); + return InferTypeOf(ast, context, evalPersmissions); + } + + /// + /// Infers the type that the result of executing a statement would have + /// + /// the ast to infer the type from + /// The current type inference context + /// The runtime usage permissions allowed during type inference + /// + internal static IList InferTypeOf(Ast ast, TypeInferenceContext context, TypeInferenceRuntimePermissions evalPersmissions = TypeInferenceRuntimePermissions.None) + { + var originalRuntimePermissions = context.RuntimePermissions; + try + { + context.RuntimePermissions = evalPersmissions; + return context.InferType(ast, new TypeInferenceVisitor(context)).ToList(); + } + finally + { + context.RuntimePermissions = originalRuntimePermissions; + } + } + } + + internal class TypeInferenceContext + { + public static readonly PSTypeName[] EmptyPSTypeNameArray = Utils.EmptyArray(); + private readonly PowerShell _powerShell; + + public TypeInferenceContext() : this(PowerShell.Create(RunspaceMode.CurrentRunspace)) + { + } + + /// + /// Create a new Type Inference context. + /// The powerShell instance passed need to have a non null Runspace + /// + /// + public TypeInferenceContext(PowerShell powerShell) + { + Diagnostics.Assert(powerShell.Runspace != null, "Callers are required to ensure we have a runspace"); + _powerShell = powerShell; + + Helper = new PowerShellExecutionHelper(powerShell); + } + + + public TypeDefinitionAst CurrentTypeDefinitionAst { get; set; } + + public TypeInferenceRuntimePermissions RuntimePermissions { get; set; } + + internal PowerShellExecutionHelper Helper { get; } + + internal ExecutionContext ExecutionContext => _powerShell.Runspace.ExecutionContext; + + public bool TryGetRepresentativeTypeNameFromExpressionSafeEval(ExpressionAst expression, out PSTypeName typeName) + { + typeName = null; + if (RuntimePermissions != TypeInferenceRuntimePermissions.AllowSafeEval) + { + return false; + } + object value; + return expression != null && + SafeExprEvaluator.TrySafeEval(expression, ExecutionContext, out value) && + TryGetRepresentativeTypeNameFromValue(value, out typeName); + } + + internal IList GetMembersByInferredType(PSTypeName typename, bool isStatic, Func filter) + { + List results = new List(); + + Func filterToCall = filter; + if (typename.Type != null) + { + AddMembersByInferredTypesClrType(typename, isStatic, filter, filterToCall, results); + } + else if (typename.TypeDefinitionAst != null) + { + AddMembersByInferredTypeDefinitionAst(typename, isStatic, filter, filterToCall, results); + } + else + { + // Look in the type table first. + if (!isStatic) + { + var consolidatedString = new ConsolidatedString(new[] { typename.Name }); + results.AddRange(ExecutionContext.TypeTable.GetMembers(consolidatedString)); + } + + AddMembersByInferredTypeCimType(typename, results, filterToCall); + } + + return results; + } + + + internal void AddMembersByInferredTypesClrType(PSTypeName typename, bool isStatic, Func filter, Func filterToCall, List results) + { + if (CurrentTypeDefinitionAst == null || CurrentTypeDefinitionAst.Type != typename.Type) + { + if (filterToCall == null) + { + filterToCall = o => !IsMemberHidden(o); + } + else + { + filterToCall = o => !IsMemberHidden(o) && filter(o); + } + } + IEnumerable elementTypes; + if (typename.Type.IsArray) + { + elementTypes = new[] { typename.Type.GetElementType() }; + } + else + { + elementTypes = typename.Type.GetInterfaces().Where( + t => t.GetTypeInfo().IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>)); + } + foreach (var type in elementTypes.Prepend(typename.Type)) + { + // Look in the type table first. + if (!isStatic) + { + var consolidatedString = DotNetAdapter.GetInternedTypeNameHierarchy(type); + results.AddRange(ExecutionContext.TypeTable.GetMembers(consolidatedString)); + } + + var members = isStatic + ? PSObject.dotNetStaticAdapter.BaseGetMembers(type) + : PSObject.dotNetInstanceAdapter.GetPropertiesAndMethods(type, false); + results.AddRange(filterToCall != null ? members.Where(filterToCall) : members); + } + } + + internal void AddMembersByInferredTypeDefinitionAst(PSTypeName typename, bool isStatic, + Func filter, Func filterToCall, List results) + { + if (CurrentTypeDefinitionAst != typename.TypeDefinitionAst) + { + if (filterToCall == null) + filterToCall = o => !IsMemberHidden(o); + else + filterToCall = o => !IsMemberHidden(o) && filter(o); + } + + bool foundConstructor = false; + foreach (var member in typename.TypeDefinitionAst.Members) + { + bool add; + var propertyMember = member as PropertyMemberAst; + if (propertyMember != null) + { + add = propertyMember.IsStatic == isStatic; + } + else + { + var functionMember = (FunctionMemberAst) member; + add = functionMember.IsStatic == isStatic; + foundConstructor |= functionMember.IsConstructor; + } + + if (filterToCall != null && add) + { + add = filterToCall(member); + } + + if (add) + { + results.Add(member); + } + } + + //iterate through bases/interfaces + foreach (var baseType in typename.TypeDefinitionAst.BaseTypes) + { + var baseTypeName = baseType.TypeName as TypeName; + if (baseTypeName == null) continue; + var baseTypeDefinitionAst = baseTypeName._typeDefinitionAst; + results.AddRange(GetMembersByInferredType(new PSTypeName(baseTypeDefinitionAst), isStatic, filterToCall)); + } + + // Add stuff from our base class System.Object. + if (isStatic) + { + // Don't add base class constructors + if (filter == null) + { + filterToCall = o => !IsConstructor(o); + } + + else + { + filterToCall = o => !IsConstructor(o) && filter(o); + } + + if (!foundConstructor) + { + results.Add( + new CompilerGeneratedMemberFunctionAst(PositionUtilities.EmptyExtent, typename.TypeDefinitionAst, + SpecialMemberFunctionType.DefaultConstructor)); + } + } + else + { + // Reset the filter because the recursive call will add IsHidden back if necessary. + filterToCall = filter; + } + results.AddRange(GetMembersByInferredType(new PSTypeName(typeof(object)), isStatic, filterToCall)); + } + + + internal void AddMembersByInferredTypeCimType(PSTypeName typename, List results, Func filterToCall) + { + string cimNamespace; + string className; + if (ParseCimCommandsTypeName(typename, out cimNamespace, out className)) + { + var powerShellExecutionHelper = Helper; + powerShellExecutionHelper + .AddCommandWithPreferenceSetting("CimCmdlets\\Get-CimClass") + .AddParameter("Namespace", cimNamespace) + .AddParameter("Class", className); + var classes = powerShellExecutionHelper.ExecuteCurrentPowerShell(out _); + foreach (var cimClass in classes.Select(PSObject.Base).OfType()) + { + if (filterToCall == null) + { + results.AddRange(cimClass.CimClassProperties); + } + else + { + foreach (var prop in cimClass.CimClassProperties) + { + if (filterToCall(prop)) + { + results.Add(prop); + } + } + } + } + } + } + + internal IEnumerable InferType(Ast ast, TypeInferenceVisitor visitor) + { + var res = ast.Accept(visitor); + Diagnostics.Assert(res != null, "Fix visit methods to not return null"); + return (IEnumerable)res; + } + + private static bool TryGetRepresentativeTypeNameFromValue(object value, out PSTypeName type) + { + type = null; + if (value != null) + { + var list = value as IList; + if (list != null && list.Count > 0) + { + value = list[0]; + } + value = PSObject.Base(value); + if (value != null) + { + type = new PSTypeName(value.GetType()); + return true; + } + } + return false; + } + + internal static bool ParseCimCommandsTypeName(PSTypeName typename, out string cimNamespace, out string className) + { + cimNamespace = null; + className = null; + if (typename == null) + { + return false; + } + if (typename.Type != null) + { + return false; + } + + var match = Regex.Match(typename.Name, "(?.*)#(?.*)[/\\\\](?.*)"); + if (!match.Success) + { + return false; + } + + if (!match.Groups["NetTypeName"].Value.EqualsOrdinalIgnoreCase(typeof(CimInstance).FullName)) + { + return false; + } + + cimNamespace = match.Groups["CimNamespace"].Value; + className = match.Groups["CimClassName"].Value; + return true; + } + + private static bool IsMemberHidden(object member) + { + switch (member) + { + case PSMemberInfo psMemberInfo: + return psMemberInfo.IsHidden; + case MemberInfo memberInfo: + return memberInfo.GetCustomAttributes(typeof(HiddenAttribute), false).Any(); + case PropertyMemberAst propertyMember: + return propertyMember.IsHidden; + case FunctionMemberAst functionMember: + return functionMember.IsHidden; + } + + return false; + } + + private static bool IsConstructor(object member) + { + var psMethod = member as PSMethod; + var methodCacheEntry = psMethod?.adapterData as DotNetAdapter.MethodCacheEntry; + return methodCacheEntry != null && methodCacheEntry.methodInformationStructures[0].method.IsConstructor; + } + } + + internal class TypeInferenceVisitor : ICustomAstVisitor2 + { + private readonly TypeInferenceContext _context; + + private static readonly PSTypeName StringPSTypeName = new PSTypeName(typeof(string)); + + public TypeInferenceVisitor(TypeInferenceContext context) + { + _context = context; + } + + private IEnumerable InferTypes(Ast ast) + { + return _context.InferType(ast, this); + } + + object ICustomAstVisitor.VisitTypeExpression(TypeExpressionAst typeExpressionAst) + { + return new[] { new PSTypeName(typeExpressionAst.StaticType) }; + } + + object ICustomAstVisitor.VisitMemberExpression(MemberExpressionAst memberExpressionAst) + { + return InferTypesFrom(memberExpressionAst); + } + + object ICustomAstVisitor.VisitInvokeMemberExpression(InvokeMemberExpressionAst invokeMemberExpressionAst) + { + return InferTypesFrom(invokeMemberExpressionAst); + } + + object ICustomAstVisitor.VisitArrayExpression(ArrayExpressionAst arrayExpressionAst) + { + return new[] { new PSTypeName(typeof(object[])) }; + } + + object ICustomAstVisitor.VisitArrayLiteral(ArrayLiteralAst arrayLiteralAst) + { + return new[] { new PSTypeName(typeof(object[])) }; + } + + object ICustomAstVisitor.VisitHashtable(HashtableAst hashtableAst) + { + return new[] { new PSTypeName(typeof(Hashtable)) }; + } + + object ICustomAstVisitor.VisitScriptBlockExpression(ScriptBlockExpressionAst scriptBlockExpressionAst) + { + return new[] { new PSTypeName(typeof(ScriptBlock)) }; + } + + object ICustomAstVisitor.VisitParenExpression(ParenExpressionAst parenExpressionAst) + { + return parenExpressionAst.Pipeline.Accept(this); + } + + object ICustomAstVisitor.VisitExpandableStringExpression(ExpandableStringExpressionAst expandableStringExpressionAst) + { + return new[] { StringPSTypeName }; + } + + object ICustomAstVisitor.VisitIndexExpression(IndexExpressionAst indexExpressionAst) + { + return InferTypeFrom(indexExpressionAst); + } + + object ICustomAstVisitor.VisitAttributedExpression(AttributedExpressionAst attributedExpressionAst) + { + return attributedExpressionAst.Child.Accept(this); + } + + object ICustomAstVisitor.VisitBlockStatement(BlockStatementAst blockStatementAst) + { + return blockStatementAst.Body.Accept(this); + } + + object ICustomAstVisitor.VisitUsingExpression(UsingExpressionAst usingExpressionAst) + { + return usingExpressionAst.SubExpression.Accept(this); + } + + object ICustomAstVisitor.VisitVariableExpression(VariableExpressionAst ast) + { + return InferTypeFrom(ast); + } + + object ICustomAstVisitor.VisitMergingRedirection(MergingRedirectionAst mergingRedirectionAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitBinaryExpression(BinaryExpressionAst binaryExpressionAst) + { + return InferTypes(binaryExpressionAst.Left); + } + + object ICustomAstVisitor.VisitUnaryExpression(UnaryExpressionAst unaryExpressionAst) + { + var tokenKind = unaryExpressionAst.TokenKind; + return (tokenKind == TokenKind.Not || tokenKind == TokenKind.Exclaim) + ? BinaryExpressionAst.BoolTypeNameArray + : unaryExpressionAst.Child.Accept(this); + } + + object ICustomAstVisitor.VisitConvertExpression(ConvertExpressionAst convertExpressionAst) + { + var type = convertExpressionAst.Type.TypeName.GetReflectionType(); + var psTypeName = type != null ? new PSTypeName(type) : new PSTypeName(convertExpressionAst.Type.TypeName.FullName); + return new[] { psTypeName }; + } + + object ICustomAstVisitor.VisitConstantExpression(ConstantExpressionAst constantExpressionAst) + { + var value = constantExpressionAst.Value; + return value != null ? new[] { new PSTypeName(value.GetType()) } : TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitStringConstantExpression(StringConstantExpressionAst stringConstantExpressionAst) + { + return new[] { StringPSTypeName }; + } + + object ICustomAstVisitor.VisitSubExpression(SubExpressionAst subExpressionAst) + { + return subExpressionAst.SubExpression.Accept(this); + } + + object ICustomAstVisitor.VisitErrorStatement(ErrorStatementAst errorStatementAst) + { + return errorStatementAst.Conditions.Concat(errorStatementAst.Bodies).Concat(errorStatementAst.NestedAst).SelectMany(InferTypes); + } + + object ICustomAstVisitor.VisitErrorExpression(ErrorExpressionAst errorExpressionAst) + { + return errorExpressionAst.NestedAst.SelectMany(InferTypes); + } + + object ICustomAstVisitor.VisitScriptBlock(ScriptBlockAst scriptBlockAst) + { + var res = new List(10); + var beginBlock = scriptBlockAst.BeginBlock; + var processBlock = scriptBlockAst.ProcessBlock; + var endBlock = scriptBlockAst.EndBlock; + // The following is used when we don't find OutputType, which is checked elsewhere. + if (beginBlock != null) + { + res.AddRange(InferTypes(beginBlock)); + } + if (processBlock != null) + { + res.AddRange(InferTypes(processBlock)); + } + if (endBlock != null) + { + res.AddRange(InferTypes(endBlock)); + } + return res; + } + + object ICustomAstVisitor.VisitParamBlock(ParamBlockAst paramBlockAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitNamedBlock(NamedBlockAst namedBlockAst) + { + return namedBlockAst.Statements.SelectMany(InferTypes); + } + + object ICustomAstVisitor.VisitTypeConstraint(TypeConstraintAst typeConstraintAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitAttribute(AttributeAst attributeAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitNamedAttributeArgument(NamedAttributeArgumentAst namedAttributeArgumentAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitParameter(ParameterAst parameterAst) + { + var res = new List(); + var attributes = parameterAst.Attributes; + var typeConstraint = attributes.OfType().FirstOrDefault(); + if (typeConstraint != null) + { + res.Add(new PSTypeName(typeConstraint.TypeName)); + } + foreach (var attributeAst in attributes.OfType()) + { + PSTypeNameAttribute attribute = null; + try + { + attribute = attributeAst.GetAttribute() as PSTypeNameAttribute; + } + catch (RuntimeException) { } + if (attribute != null) + { + res.Add(new PSTypeName(attribute.PSTypeName)); + } + } + return res; + } + + object ICustomAstVisitor.VisitFunctionDefinition(FunctionDefinitionAst functionDefinitionAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitStatementBlock(StatementBlockAst statementBlockAst) + { + return statementBlockAst.Statements.SelectMany(InferTypes); + } + + object ICustomAstVisitor.VisitIfStatement(IfStatementAst ifStmtAst) + { + var res = new List(); + + res.AddRange(ifStmtAst.Clauses.SelectMany(clause => InferTypes(clause.Item2))); + + var elseClause = ifStmtAst.ElseClause; + if (elseClause != null) + { + res.AddRange(InferTypes(elseClause)); + } + return res; + } + + object ICustomAstVisitor.VisitTrap(TrapStatementAst trapStatementAst) + { + return trapStatementAst.Body.Accept(this); + } + + object ICustomAstVisitor.VisitSwitchStatement(SwitchStatementAst switchStatementAst) + { + var res = new List(8); + var clauses = switchStatementAst.Clauses; + var defaultStatement = switchStatementAst.Default; + + res.AddRange(clauses.SelectMany(clause => InferTypes(clause.Item2))); + + if (defaultStatement != null) + { + res.AddRange(InferTypes(defaultStatement)); + } + return res; + } + + object ICustomAstVisitor.VisitDataStatement(DataStatementAst dataStatementAst) + { + return dataStatementAst.Body.Accept(this); + } + + object ICustomAstVisitor.VisitForEachStatement(ForEachStatementAst forEachStatementAst) + { + return forEachStatementAst.Body.Accept(this); + } + + object ICustomAstVisitor.VisitDoWhileStatement(DoWhileStatementAst doWhileStatementAst) + { + return doWhileStatementAst.Body.Accept(this); + } + + object ICustomAstVisitor.VisitForStatement(ForStatementAst forStatementAst) + { + return forStatementAst.Body.Accept(this); + } + + object ICustomAstVisitor.VisitWhileStatement(WhileStatementAst whileStatementAst) + { + return whileStatementAst.Body.Accept(this); + } + + object ICustomAstVisitor.VisitCatchClause(CatchClauseAst catchClauseAst) + { + return catchClauseAst.Body.Accept(this); + } + + object ICustomAstVisitor.VisitTryStatement(TryStatementAst tryStatementAst) + { + var res = new List(5); + res.AddRange(InferTypes(tryStatementAst.Body)); + foreach (var catchClauseAst in tryStatementAst.CatchClauses) + { + res.AddRange(InferTypes(catchClauseAst)); + } + if (tryStatementAst.Finally != null) + { + res.AddRange(InferTypes(tryStatementAst.Finally)); + } + return res; + } + + object ICustomAstVisitor.VisitBreakStatement(BreakStatementAst breakStatementAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitContinueStatement(ContinueStatementAst continueStatementAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitReturnStatement(ReturnStatementAst returnStatementAst) + { + return returnStatementAst.Pipeline.Accept(this); + } + + object ICustomAstVisitor.VisitExitStatement(ExitStatementAst exitStatementAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitThrowStatement(ThrowStatementAst throwStatementAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitDoUntilStatement(DoUntilStatementAst doUntilStatementAst) + { + return doUntilStatementAst.Body.Accept(this); + } + + object ICustomAstVisitor.VisitAssignmentStatement(AssignmentStatementAst assignmentStatementAst) + { + return assignmentStatementAst.Left.Accept(this); + } + + object ICustomAstVisitor.VisitPipeline(PipelineAst pipelineAst) + { + return pipelineAst.PipelineElements.Last().Accept(this); + } + + object ICustomAstVisitor.VisitCommand(CommandAst commandAst) + { + return InferTypesFrom(commandAst); + } + + object ICustomAstVisitor.VisitCommandExpression(CommandExpressionAst commandExpressionAst) + { + return commandExpressionAst.Expression.Accept(this); + } + + object ICustomAstVisitor.VisitCommandParameter(CommandParameterAst commandParameterAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor.VisitFileRedirection(FileRedirectionAst fileRedirectionAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + private IEnumerable InferTypesFrom(CommandAst commandAst) + { + PseudoBindingInfo pseudoBinding = new PseudoParameterBinder() + .DoPseudoParameterBinding(commandAst, null, null, PseudoParameterBinder.BindingType.ParameterCompletion); + if (pseudoBinding?.CommandInfo == null) + { + yield break; + } + + AstParameterArgumentPair pathArgument; + string pathParameterName = "Path"; + if (!pseudoBinding.BoundArguments.TryGetValue(pathParameterName, out pathArgument)) + { + pathParameterName = "LiteralPath"; + pseudoBinding.BoundArguments.TryGetValue(pathParameterName, out pathArgument); + } + + // The OutputType on cmdlets like Get-ChildItem may depend on the path. + // The CmdletInfo returned based on just the command name will specify returning all possibilities, e.g.certificates, environment, registry, etc. + // If you specified - Path, the list of OutputType can be refined, but we have to make a copy of the CmdletInfo object this way to get that refinement. + var commandInfo = pseudoBinding.CommandInfo; + var pathArgumentPair = pathArgument as AstPair; + if (pathArgumentPair?.Argument is StringConstantExpressionAst) + { + var pathValue = ((StringConstantExpressionAst)pathArgumentPair.Argument).Value; + try + { + commandInfo = commandInfo.CreateGetCommandCopy(new object[] { "-" + pathParameterName, pathValue }); + } + catch (InvalidOperationException) { } + } + + var cmdletInfo = commandInfo as CmdletInfo; + if (cmdletInfo != null) + { + // Special cases + var inferTypesFromObjectCmdlets = InferTypesFromObjectCmdlets(commandAst, cmdletInfo, pseudoBinding).ToArray(); + if (inferTypesFromObjectCmdlets.Length > 0) + { + foreach (var objectCmdletTypes in inferTypesFromObjectCmdlets) + { + yield return objectCmdletTypes; + } + yield break; + } + } + + // The OutputType property ignores the parameter set specified in the OutputTypeAttribute. + // With psuedo-binding, we actually know the candidate parameter sets, so we could take + // advantage of it here, but I opted for the simpler code because so few cmdlets use + // ParameterSetName in OutputType and of the ones I know about, it isn't that useful. + foreach (var result in commandInfo.OutputType) + { + yield return result; + } + } + + /// + /// Infer types from the well-known object cmdlets, like foreach-object, where-object, sort-object etc + /// + /// + /// + /// + /// + private IEnumerable InferTypesFromObjectCmdlets(CommandAst commandAst, CmdletInfo cmdletInfo, PseudoBindingInfo pseudoBinding) + { + // new-object - yields an instance of whatever -Type is bound to + if (cmdletInfo.ImplementingType.FullName.EqualsOrdinalIgnoreCase("Microsoft.PowerShell.Commands.NewObjectCommand")) + { + var newObjectType = InferTypesFromNewObjectCommand(pseudoBinding); + if (newObjectType != null) + { + yield return newObjectType; + } + + yield break; // yield break; + } + + // Get-CimInstance/New-CimInstance - yields a CimInstance with ETS type based on its arguments for -Namespace and -ClassName parameters + if ( + cmdletInfo.ImplementingType.FullName.EqualsOrdinalIgnoreCase("Microsoft.Management.Infrastructure.CimCmdlets.GetCimInstanceCommand") || + cmdletInfo.ImplementingType.FullName.EqualsOrdinalIgnoreCase("Microsoft.Management.Infrastructure.CimCmdlets.NewCimInstanceCommand")) + { + foreach (var cimType in InferTypesFromCimCommand(pseudoBinding)) + { + yield return cimType; + } + yield break; // yield break; + } + + // where-object - yields whatever we saw before where-object in the pipeline. + // same for sort-object + if (cmdletInfo.ImplementingType == typeof(WhereObjectCommand) + || + cmdletInfo.ImplementingType.FullName.EqualsOrdinalIgnoreCase("Microsoft.PowerShell.Commands.SortObjectCommand")) + { + foreach (var whereOrSortType in InferTypesFromWhereAndSortCommand(commandAst)) + { + yield return whereOrSortType; + } + + // We could also check -InputObject, but that is rarely used. But don't bother continuing. + yield break; // yield break; + } + + // foreach-object - yields the type of it's script block parameters + if (cmdletInfo.ImplementingType == typeof(ForEachObjectCommand)) + { + foreach (var foreachType in InferTypesFromForeachCommand(pseudoBinding)) + { + yield return foreachType; + } + } + } + + private static IEnumerable InferTypesFromCimCommand(PseudoBindingInfo pseudoBinding) + { + string pseudoboundNamespace = + CompletionCompleters.NativeCommandArgumentCompletion_ExtractSecondaryArgument(pseudoBinding.BoundArguments, + "Namespace").FirstOrDefault(); + string pseudoboundClassName = + CompletionCompleters.NativeCommandArgumentCompletion_ExtractSecondaryArgument(pseudoBinding.BoundArguments, + "ClassName").FirstOrDefault(); + if (!string.IsNullOrWhiteSpace(pseudoboundClassName)) + { + yield return new PSTypeName(string.Format( + CultureInfo.InvariantCulture, + "{0}#{1}/{2}", + typeof(CimInstance).FullName, + pseudoboundNamespace ?? "root/cimv2", + pseudoboundClassName)); + } + yield return new PSTypeName(typeof(CimInstance)); + } + + private IEnumerable InferTypesFromForeachCommand(PseudoBindingInfo pseudoBinding) + { + AstParameterArgumentPair argument; + if (pseudoBinding.BoundArguments.TryGetValue("Begin", out argument)) + { + foreach (var type in GetInferredTypeFromScriptBlockParameter(argument)) + { + yield return type; + } + } + + if (pseudoBinding.BoundArguments.TryGetValue("Process", out argument)) + { + foreach (var type in GetInferredTypeFromScriptBlockParameter(argument)) + { + yield return type; + } + } + + if (pseudoBinding.BoundArguments.TryGetValue("End", out argument)) + { + foreach (var type in GetInferredTypeFromScriptBlockParameter(argument)) + { + yield return type; + } + } + } + + private IEnumerable InferTypesFromWhereAndSortCommand(CommandAst commandAst) + { + var parentPipeline = commandAst.Parent as PipelineAst; + if (parentPipeline != null) + { + int i; + for (i = 0; i < parentPipeline.PipelineElements.Count; i++) + { + if (parentPipeline.PipelineElements[i] == commandAst) + break; + } + if (i > 0) + { + foreach (var typename in InferTypes(parentPipeline.PipelineElements[i - 1])) + { + yield return typename; + } + } + } + } + + private static PSTypeName InferTypesFromNewObjectCommand(PseudoBindingInfo pseudoBinding) + { + AstParameterArgumentPair typeArgument; + if (pseudoBinding.BoundArguments.TryGetValue("TypeName", out typeArgument)) + { + var typeArgumentPair = typeArgument as AstPair; + var stringConstantExpr = typeArgumentPair?.Argument as StringConstantExpressionAst; + if (stringConstantExpr != null) + { + return new PSTypeName(stringConstantExpr.Value); + } + } + return null; + } + + private IEnumerable InferTypesFrom(MemberExpressionAst memberExpressionAst) + { + var memberCommandElement = memberExpressionAst.Member; + var isStatic = memberExpressionAst.Static; + var expression = memberExpressionAst.Expression; + + // If the member name isn't simple, don't even try. + var memberAsStringConst = memberCommandElement as StringConstantExpressionAst; + if (memberAsStringConst == null) + yield break; + + var exprType = GetExpressionType(expression, isStatic); + if (exprType == null || exprType.Length == 0) + { + yield break; + } + + var maybeWantDefaultCtor = isStatic + && memberExpressionAst is InvokeMemberExpressionAst + && memberAsStringConst.Value.EqualsOrdinalIgnoreCase("new"); + + // We use a list of member names because we might discover aliases properties + // and if we do, we'll add to the list. + var memberNameList = new List { memberAsStringConst.Value }; + foreach (var type in exprType) + { + var members = _context.GetMembersByInferredType(type, isStatic, filter: null); + + for (int i = 0; i < memberNameList.Count; i++) + { + string memberName = memberNameList[i]; + foreach (var member in members) + { + var isInvokeMemberAst = memberExpressionAst is InvokeMemberExpressionAst; + switch (member) + { + case PropertyInfo propertyInfo: // .net property + { + if (propertyInfo.Name.EqualsOrdinalIgnoreCase(memberName) && !isInvokeMemberAst) + { + yield return new PSTypeName(propertyInfo.PropertyType); + goto NextMember; + } + continue; + } + case FieldInfo fieldInfo: // .net field + { + if (fieldInfo.Name.EqualsOrdinalIgnoreCase(memberName) && !isInvokeMemberAst) + { + yield return new PSTypeName(fieldInfo.FieldType); + } + continue; + } + + case DotNetAdapter.MethodCacheEntry methodCacheEntry: // .net method + { + if (methodCacheEntry[0].method.Name.EqualsOrdinalIgnoreCase(memberName)) + { + maybeWantDefaultCtor = false; + if (isInvokeMemberAst) + { + foreach (var method in methodCacheEntry.methodInformationStructures) + { + var methodInfo = method.method as MethodInfo; + if (methodInfo != null && !methodInfo.ReturnType.GetTypeInfo().ContainsGenericParameters) + { + yield return new PSTypeName(methodInfo.ReturnType); + } + } + } + else + { + // Accessing a method as a property, we'd return a wrapper over the method. + yield return new PSTypeName(typeof(PSMethod)); + } + } + continue; + } + case MemberAst memberAst: // this is for members defined by PowerShell classes + { + if (memberAst.Name.EqualsOrdinalIgnoreCase(memberName)) + { + if (isInvokeMemberAst) + { + var functionMemberAst = memberAst as FunctionMemberAst; + if (functionMemberAst != null && !functionMemberAst.IsReturnTypeVoid()) + { + yield return new PSTypeName(functionMemberAst.ReturnType.TypeName); + } + } + else + { + var propertyMemberAst = memberAst as PropertyMemberAst; + if (propertyMemberAst != null) + { + if (propertyMemberAst.PropertyType != null) + { + yield return new PSTypeName(propertyMemberAst.PropertyType.TypeName); + } + else + { + yield return new PSTypeName(typeof(object)); + } + } + else + { + // Accessing a method as a property, we'd return a wrapper over the method. + yield return new PSTypeName(typeof(PSMethod)); + } + } + } + continue; + } + case PSMemberInfo memberInfo: + { + if (!memberInfo.Name.EqualsOrdinalIgnoreCase(memberName)) + { + continue; + } + switch (member) + { + case PSProperty p: + { + yield return new PSTypeName(p.Value.GetType()); + goto NextMember; + } + case PSNoteProperty noteProperty: + { + yield return new PSTypeName(noteProperty.Value.GetType()); + goto NextMember; + } + case PSAliasProperty aliasProperty: + { + memberNameList.Add(aliasProperty.ReferencedMemberName); + goto NextMember; + } + case PSCodeProperty codeProperty: + { + if (codeProperty.GetterCodeReference != null) + { + yield return new PSTypeName(codeProperty.GetterCodeReference.ReturnType); + } + goto NextMember; + } + case PSScriptProperty scriptProperty: + { + var scriptBlock = scriptProperty.GetterScript; + foreach (var t in scriptBlock.OutputType) + { + yield return t; + } + goto NextMember; + } + case PSScriptMethod scriptMethod: + { + var scriptBlock = scriptMethod.Script; + foreach (var t in scriptBlock.OutputType) + { + yield return t; + } + goto NextMember; + } + } + break; + } + } + } + NextMember: {} + } + + // We didn't find any constructors but they used [T]::new() syntax + if (maybeWantDefaultCtor) + { + yield return type; + } + } + } + + private PSTypeName[] GetExpressionType(ExpressionAst expression, bool isStatic) + { + PSTypeName[] exprType; + if (isStatic) + { + var exprAsType = expression as TypeExpressionAst; + if (exprAsType == null) + return null; + var type = exprAsType.TypeName.GetReflectionType(); + if (type == null) + { + var typeName = exprAsType.TypeName as TypeName; + if (typeName?._typeDefinitionAst == null) + return null; + + exprType = new[] {new PSTypeName(typeName._typeDefinitionAst)}; + } + else + { + exprType = new[] {new PSTypeName(type)}; + } + } + else + { + exprType = InferTypes(expression).ToArray(); + if (exprType.Length == 0) + { + if (_context.TryGetRepresentativeTypeNameFromExpressionSafeEval(expression, out PSTypeName _)) + { + return exprType; + } + return exprType; + } + } + return exprType; + } + + private IEnumerable InferTypeFrom(VariableExpressionAst variableExpressionAst) + { + // We don't need to handle drive qualified variables, we can usually get those values + // without needing to "guess" at the type. + var astVariablePath = variableExpressionAst.VariablePath; + if (!astVariablePath.IsVariable) + { + // Not a variable - the caller should have already tried going to session state + // to get the item and hence it's type, but that must have failed. Don't try again. + yield break; + } + + Ast parent = variableExpressionAst.Parent; + if (astVariablePath.IsUnqualified && + (SpecialVariables.IsUnderbar(astVariablePath.UserPath) + || astVariablePath.UserPath.EqualsOrdinalIgnoreCase(SpecialVariables.PSItem))) + { + // $_ is special, see if we're used in a script block in some pipeline. + while (parent != null) + { + if (parent is ScriptBlockExpressionAst) + break; + parent = parent.Parent; + } + + if (parent != null) + { + if (parent.Parent is CommandExpressionAst && parent.Parent.Parent is PipelineAst) + { + // Script block in a hash table, could be something like: + // dir | ft @{ Expression = { $_ } } + if (parent.Parent.Parent.Parent is HashtableAst) + { + parent = parent.Parent.Parent.Parent; + } + else if (parent.Parent.Parent.Parent is ArrayLiteralAst && parent.Parent.Parent.Parent.Parent is HashtableAst) + { + parent = parent.Parent.Parent.Parent.Parent; + } + } + if (parent.Parent is CommandParameterAst) + { + parent = parent.Parent; + } + + var commandAst = parent.Parent as CommandAst; + if (commandAst != null) + { + // We found a command, see if there is a previous command in the pipeline. + PipelineAst pipelineAst = (PipelineAst)commandAst.Parent; + var previousCommandIndex = pipelineAst.PipelineElements.IndexOf(commandAst) - 1; + if (previousCommandIndex < 0) yield break; + foreach (var result in InferTypes(pipelineAst.PipelineElements[0])) + { + if (result.Type != null) + { + // Assume (because we're looking at $_ and we're inside a script block that is an + // argument to some command) that the type we're getting is actually unrolled. + // This might not be right in all cases, but with our simple analysis, it's + // right more often than it's wrong. + if (result.Type.IsArray) + { + yield return new PSTypeName(result.Type.GetElementType()); + continue; + } + + if (typeof(IEnumerable).IsAssignableFrom(result.Type)) + { + // We can't deduce much from IEnumerable, but we can if it's generic. + var enumerableInterfaces = result.Type.GetInterfaces().Where( + t => + t.GetTypeInfo().IsGenericType && + t.GetGenericTypeDefinition() == typeof(IEnumerable<>)); + foreach (var i in enumerableInterfaces) + { + yield return new PSTypeName(i.GetGenericArguments()[0]); + } + continue; + } + } + yield return result; + } + yield break; + } + } + } + + // For certain variables, we always know their type, well at least we can assume we know. + if (astVariablePath.IsUnqualified) + { + if (!astVariablePath.UserPath.EqualsOrdinalIgnoreCase(SpecialVariables.This) || + _context.CurrentTypeDefinitionAst == null) + { + for (int i = 0; i < SpecialVariables.AutomaticVariables.Length; i++) + { + if (!astVariablePath.UserPath.EqualsOrdinalIgnoreCase(SpecialVariables.AutomaticVariables[i])) + continue; + var type = SpecialVariables.AutomaticVariableTypes[i]; + if (type != typeof(object)) + yield return new PSTypeName(type); + break; + } + } + else + { + yield return new PSTypeName(_context.CurrentTypeDefinitionAst); + yield break; + } + } + + // Look for our variable as a parameter or on the lhs of an assignment - hopefully we'll find either + // a type constraint or at least we can use the rhs to infer the type. + + while (parent.Parent != null) + { + parent = parent.Parent; + } + + if (parent.Parent is FunctionDefinitionAst) + { + parent = parent.Parent; + } + + int startOffset = variableExpressionAst.Extent.StartOffset; + var targetAsts = (List)AstSearcher.FindAll(parent, + ast => (ast is ParameterAst || ast is AssignmentStatementAst || ast is ForEachStatementAst || ast is CommandAst) + && variableExpressionAst.AstAssignsToSameVariable(ast) + && ast.Extent.EndOffset < startOffset, + searchNestedScriptBlocks: true); + + var parameterAst = targetAsts.OfType().FirstOrDefault(); + if (parameterAst != null) + { + var parameterTypes = InferTypes(parameterAst).ToArray(); + if (parameterTypes.Length > 0) + { + foreach (var parameterType in parameterTypes) + { + yield return parameterType; + } + yield break; + } + } + + var assignAsts = targetAsts.OfType().ToArray(); + + // If any of the assignments lhs use a type constraint, then we use that. + // Otherwise, we use the rhs of the "nearest" assignment + foreach (var assignAst in assignAsts) + { + var lhsConvert = assignAst.Left as ConvertExpressionAst; + if (lhsConvert != null) + { + yield return new PSTypeName(lhsConvert.Type.TypeName); + yield break; + } + } + + var foreachAst = targetAsts.OfType().FirstOrDefault(); + if (foreachAst != null) + { + foreach (var typeName in InferTypes(foreachAst.Condition)) + { + yield return typeName; + } + yield break; + } + + var commandCompletionAst = targetAsts.OfType().FirstOrDefault(); + if (commandCompletionAst != null) + { + foreach (var typeName in InferTypes(commandCompletionAst)) + { + yield return typeName; + } + yield break; + } + + int smallestDiff = int.MaxValue; + AssignmentStatementAst closestAssignment = null; + foreach (var assignAst in assignAsts) + { + var endOffset = assignAst.Extent.EndOffset; + if ((startOffset - endOffset) < smallestDiff) + { + smallestDiff = startOffset - endOffset; + closestAssignment = assignAst; + } + } + + if (closestAssignment != null) + { + foreach (var type in InferTypes(closestAssignment.Right)) + { + yield return type; + } + } + + + PSTypeName evalTypeName; + if (_context.TryGetRepresentativeTypeNameFromExpressionSafeEval(variableExpressionAst, out evalTypeName)) + { + yield return evalTypeName; + } + + } + + private IEnumerable InferTypeFrom(IndexExpressionAst indexExpressionAst) + { + var targetTypes = InferTypes(indexExpressionAst.Target); + bool foundAny = false; + foreach (var psType in targetTypes) + { + var type = psType.Type; + if (type != null) + { + if (type.IsArray) + { + yield return new PSTypeName(type.GetElementType()); + continue; + } + + foreach (var iface in type.GetInterfaces()) + { + var isGenericType = iface.GetTypeInfo().IsGenericType; + if (isGenericType && iface.GetGenericTypeDefinition() == typeof(IDictionary<,>)) + { + var valueType = iface.GetGenericArguments()[1]; + if (!valueType.GetTypeInfo().ContainsGenericParameters) + { + foundAny = true; + yield return new PSTypeName(valueType); + } + } + else if (isGenericType && iface.GetGenericTypeDefinition() == typeof(IList<>)) + { + var valueType = iface.GetGenericArguments()[0]; + if (!valueType.GetTypeInfo().ContainsGenericParameters) + { + foundAny = true; + yield return new PSTypeName(valueType); + } + } + } + + var defaultMember = type.GetCustomAttributes(true).FirstOrDefault(); + if (defaultMember != null) + { + var indexers = type.GetGetterProperty(defaultMember.MemberName); + foreach (var indexer in indexers) + { + foundAny = true; + yield return new PSTypeName(indexer.ReturnType); + } + } + } + + if (!foundAny) + { + // Inferred type of target wasn't indexable. Assume (perhaps incorrectly) + // that it came from OutputType and that more than one object was returned + // and that we're indexing because of that, in which case, OutputType really + // is the inferred type. + yield return psType; + } + } + } + + private IEnumerable GetInferredTypeFromScriptBlockParameter(AstParameterArgumentPair argument) + { + var argumentPair = argument as AstPair; + var scriptBlockExpressionAst = argumentPair?.Argument as ScriptBlockExpressionAst; + if (scriptBlockExpressionAst == null) yield break; + foreach (var type in InferTypes(scriptBlockExpressionAst.ScriptBlock)) + { + yield return type; + } + } + + object ICustomAstVisitor2.VisitTypeDefinition(TypeDefinitionAst typeDefinitionAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor2.VisitPropertyMember(PropertyMemberAst propertyMemberAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor2.VisitFunctionMember(FunctionMemberAst functionMemberAst) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor2.VisitBaseCtorInvokeMemberExpression(BaseCtorInvokeMemberExpressionAst baseCtorInvokeMemberExpressionAst) + { + return ((ICustomAstVisitor)this).VisitInvokeMemberExpression(baseCtorInvokeMemberExpressionAst); + } + + object ICustomAstVisitor2.VisitUsingStatement(UsingStatementAst usingStatement) + { + return TypeInferenceContext.EmptyPSTypeNameArray; + } + + object ICustomAstVisitor2.VisitConfigurationDefinition(ConfigurationDefinitionAst configurationDefinitionAst) + { + return configurationDefinitionAst.Body.Accept(this); + } + + object ICustomAstVisitor2.VisitDynamicKeywordStatement(DynamicKeywordStatementAst dynamicKeywordAst) + { + // TODO: What is the right InferredType for the AST + return dynamicKeywordAst.CommandElements[0].Accept(this); + } + } + + static class TypeInferenceExtension + { + public static bool EqualsOrdinalIgnoreCase(this string s, string t) + { + return string.Equals(s, t, StringComparison.OrdinalIgnoreCase); + } + + public static IEnumerable GetGetterProperty(this Type type, string propertyName) + { + return type.GetMethods(BindingFlags.Public | BindingFlags.Instance).Where( + m => + { + var name = m.Name; + // Equals without string allocation + return name.Length == propertyName.Length + 4 && + name.StartsWith("get_") && propertyName.IndexOf(name, 4, StringComparison.Ordinal) == 4; + } + ); + } + + public static bool AstAssignsToSameVariable(this VariableExpressionAst variableAst, Ast ast) + { + var parameterAst = ast as ParameterAst; + var variableAstVariablePath = variableAst.VariablePath; + if (parameterAst != null) + { + return variableAstVariablePath.IsUnscopedVariable && + parameterAst.Name.VariablePath.UnqualifiedPath.Equals(variableAstVariablePath.UnqualifiedPath, StringComparison.OrdinalIgnoreCase); + } + + var foreachAst = ast as ForEachStatementAst; + if (foreachAst != null) + { + return variableAstVariablePath.IsUnscopedVariable && + foreachAst.Variable.VariablePath.UnqualifiedPath.Equals(variableAstVariablePath.UnqualifiedPath, StringComparison.OrdinalIgnoreCase); + } + + var commandAst = ast as CommandAst; + if (commandAst != null) + { + string[] variableParameters = new string[] { "PV", "PipelineVariable", "OV", "OutVariable" }; + StaticBindingResult bindingResult = StaticParameterBinder.BindCommand(commandAst, false, variableParameters); + + if (bindingResult != null) + { + foreach (string commandVariableParameter in variableParameters) + { + if (bindingResult.BoundParameters.TryGetValue(commandVariableParameter, out ParameterBindingResult parameterBindingResult)) + { + if (String.Equals(variableAstVariablePath.UnqualifiedPath, (string)parameterBindingResult.ConstantValue, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + } + } + } + + return false; + } + + var assignmentAst = (AssignmentStatementAst)ast; + var lhs = assignmentAst.Left; + var convertExpr = lhs as ConvertExpressionAst; + if (convertExpr != null) + { + lhs = convertExpr.Child; + } + + var varExpr = lhs as VariableExpressionAst; + if (varExpr == null) + return false; + + var candidateVarPath = varExpr.VariablePath; + if (candidateVarPath.UserPath.Equals(variableAstVariablePath.UserPath, StringComparison.OrdinalIgnoreCase)) + return true; + + // The following condition is making an assumption that at script scope, we didn't use $script:, but in the local scope, we did + // If we are searching anything other than script scope, this is wrong. + if (variableAstVariablePath.IsScript && variableAstVariablePath.UnqualifiedPath.Equals(candidateVarPath.UnqualifiedPath, StringComparison.OrdinalIgnoreCase)) + return true; + + return false; + } + } +} diff --git a/src/System.Management.Automation/engine/parser/VariableAnalysis.cs b/src/System.Management.Automation/engine/parser/VariableAnalysis.cs index 7e6fa1734c0..f8d85c16e02 100644 --- a/src/System.Management.Automation/engine/parser/VariableAnalysis.cs +++ b/src/System.Management.Automation/engine/parser/VariableAnalysis.cs @@ -457,12 +457,6 @@ internal override AstVisitAction InternalVisit(AstVisitor visitor) Diagnostics.Assert(false, "This code is unreachable."); return AstVisitAction.Continue; } - - internal override IEnumerable GetInferredType(CompletionContext context) - { - Diagnostics.Assert(false, "This code is unreachable."); - return Ast.EmptyPSTypeNameArray; - } } internal static string GetUnaliasedVariableName(string varName) diff --git a/src/System.Management.Automation/engine/parser/ast.cs b/src/System.Management.Automation/engine/parser/ast.cs index f6a187bbb61..aed4cbd523d 100644 --- a/src/System.Management.Automation/engine/parser/ast.cs +++ b/src/System.Management.Automation/engine/parser/ast.cs @@ -281,8 +281,6 @@ internal void ClearParent() internal abstract object Accept(ICustomAstVisitor visitor); internal abstract AstVisitAction InternalVisit(AstVisitor visitor); - internal abstract IEnumerable GetInferredType(CompletionContext context); - internal static PSTypeName[] EmptyPSTypeNameArray = Utils.EmptyArray(); internal bool IsInWorkflow() @@ -422,12 +420,6 @@ internal override AstVisitAction InternalVisit(AstVisitor visitor) Diagnostics.Assert(false, "code should be unreachable"); return visitor.CheckForPostAction(this, AstVisitAction.Continue); } - - internal override IEnumerable GetInferredType(CompletionContext context) - { - Diagnostics.Assert(false, "code should be unreachable"); - return Ast.EmptyPSTypeNameArray; - } } /// @@ -567,11 +559,6 @@ public override Ast Copy() } } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Conditions.Concat(Bodies).Concat(NestedAst).SelectMany(nestedAst => nestedAst.GetInferredType(context)); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -662,11 +649,6 @@ public override Ast Copy() return new ErrorExpressionAst(this.Extent, newNestedAst); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return NestedAst.SelectMany(nestedAst => nestedAst.GetInferredType(context)); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -1300,32 +1282,6 @@ internal void PerformPostParseChecks(Parser parser) Diagnostics.Assert(PostParseChecksPerformed, "Post parse checks not set during semantic checks"); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - // The following is used when we don't find OutputType, which is checked elsewhere. - if (BeginBlock != null) - { - foreach (var typename in BeginBlock.GetInferredType(context)) - { - yield return typename; - } - } - if (ProcessBlock != null) - { - foreach (var typename in ProcessBlock.GetInferredType(context)) - { - yield return typename; - } - } - if (EndBlock != null) - { - foreach (var typename in EndBlock.GetInferredType(context)) - { - yield return typename; - } - } - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -1626,11 +1582,6 @@ public override Ast Copy() return new ParamBlockAst(this.Extent, newAttributes, newParameters); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -1819,11 +1770,6 @@ public override Ast Copy() internal IScriptExtent OpenCurlyExtent { get; private set; } internal IScriptExtent CloseCurlyExtent { get; private set; } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Statements.SelectMany(ast => ast.GetInferredType(context)); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -1911,11 +1857,6 @@ public override Ast Copy() return new NamedAttributeArgumentAst(this.Extent, this.ArgumentName, newArgument, this.ExpressionOmitted); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -1967,11 +1908,6 @@ protected AttributeBaseAst(IScriptExtent extent, ITypeName typeName) public ITypeName TypeName { get; private set; } internal abstract Attribute GetAttribute(); - - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } } /// @@ -2234,28 +2170,6 @@ public override Ast Copy() return new ParameterAst(this.Extent, newName, newAttributes, newDefaultValue); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - var typeConstraint = Attributes.OfType().FirstOrDefault(); - if (typeConstraint != null) - { - yield return new PSTypeName(typeConstraint.TypeName); - } - foreach (var attributeAst in Attributes.OfType()) - { - PSTypeNameAttribute attribute = null; - try - { - attribute = attributeAst.GetAttribute() as PSTypeNameAttribute; - } - catch (RuntimeException) { } - if (attribute != null) - { - yield return new PSTypeName(attribute.PSTypeName); - } - } - } - internal string GetTooltip() { var typeConstraint = Attributes.OfType().FirstOrDefault(); @@ -2418,11 +2332,6 @@ public override Ast Copy() return new StatementBlockAst(this.Extent, newStatements, newTraps); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Statements.SelectMany(nestedAst => nestedAst.GetInferredType(context)); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -2704,11 +2613,6 @@ internal override AstVisitAction InternalVisit(AstVisitor visitor) return visitor.CheckForPostAction(this, action); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } - #endregion Visitors } @@ -2932,11 +2836,6 @@ internal override AstVisitAction InternalVisit(AstVisitor visitor) #endregion - internal override IEnumerable GetInferredType(CompletionContext context) - { - throw new NotImplementedException(); - } - /// /// Define imported module and all type definitions imported by this using statement. /// @@ -3163,11 +3062,6 @@ internal override AstVisitAction InternalVisit(AstVisitor visitor) return visitor.CheckForPostAction(this, action); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - throw new NotImplementedException(); - } - #endregion Visitors } @@ -3383,11 +3277,6 @@ internal override AstVisitAction InternalVisit(AstVisitor visitor) return visitor.CheckForPostAction(this, action); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - throw new NotImplementedException(); - } - #endregion Visitors #region IParameterMetadataProvider implementation @@ -3510,12 +3399,6 @@ internal override AstVisitAction InternalVisit(AstVisitor visitor) return AstVisitAction.Continue; } - internal override IEnumerable GetInferredType(CompletionContext context) - { - Diagnostics.Assert(false, "code should be unreachable"); - return Ast.EmptyPSTypeNameArray; - } - public bool HasAnyScriptBlockAttributes() { return ((IParameterMetadataProvider)Body).HasAnyScriptBlockAttributes(); @@ -3716,11 +3599,6 @@ public override Ast Copy() return new FunctionDefinitionAst(this.Extent, this.IsFilter, this.IsWorkflow, this.Name, newParameters, newBody) { NameExtent = this.NameExtent }; } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } - internal string GetParamTextFromParameterList(Tuple, string> usingVariablesTuple = null) { Diagnostics.Assert(Parameters != null, "Caller makes sure that Parameters is not null before calling this method."); @@ -3935,22 +3813,6 @@ public override Ast Copy() return new IfStatementAst(this.Extent, newClauses, newElseClause); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - foreach (var typename in Clauses.SelectMany(clause => clause.Item2.GetInferredType(context))) - { - yield return typename; - } - - if (ElseClause != null) - { - foreach (var typename in ElseClause.GetInferredType(context)) - { - yield return typename; - } - } - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -4057,11 +3919,6 @@ public override Ast Copy() return new DataStatementAst(this.Extent, this.Variable, newCommandsAllowed, newBody); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Body.GetInferredType(context); - } - internal bool HasNonConstantAllowedCommand { get; private set; } #region Visitors @@ -4166,11 +4023,6 @@ protected LoopStatementAst(IScriptExtent extent, string label, PipelineBaseAst c /// The body of a loop statement. This property is never null. /// public StatementBlockAst Body { get; private set; } - - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Body.GetInferredType(context); - } } /// @@ -4714,22 +4566,6 @@ public override Ast Copy() return new SwitchStatementAst(this.Extent, this.Label, newCondition, this.Flags, newClauses, newDefault); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - foreach (var typename in Clauses.SelectMany(clause => clause.Item2.GetInferredType(context))) - { - yield return typename; - } - - if (Default != null) - { - foreach (var typename in Default.GetInferredType(context)) - { - yield return typename; - } - } - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -4833,11 +4669,6 @@ public override Ast Copy() return new CatchClauseAst(this.Extent, newCatchTypes, newBody); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Body.GetInferredType(context); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -4952,27 +4783,6 @@ public override Ast Copy() return new TryStatementAst(this.Extent, newBody, newCatchClauses, newFinally); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - foreach (var typename in Body.GetInferredType(context)) - { - yield return typename; - } - - foreach (var typename in CatchClauses.SelectMany(clause => clause.Body.GetInferredType(context))) - { - yield return typename; - } - - if (Finally != null) - { - foreach (var typename in Finally.GetInferredType(context)) - { - yield return typename; - } - } - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -5057,11 +4867,6 @@ public override Ast Copy() return new TrapStatementAst(this.Extent, newTrapType, newBody); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Body.GetInferredType(context); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -5125,11 +4930,6 @@ public override Ast Copy() return new BreakStatementAst(this.Extent, newLabel); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -5187,11 +4987,6 @@ public override Ast Copy() return new ContinueStatementAst(this.Extent, newLabel); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -5249,11 +5044,6 @@ public override Ast Copy() return new ReturnStatementAst(this.Extent, newPipeline); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -5311,11 +5101,6 @@ public override Ast Copy() return new ExitStatementAst(this.Extent, newPipeline); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -5405,11 +5190,6 @@ public override Ast Copy() return new ThrowStatementAst(this.Extent, newPipeline); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -5543,11 +5323,6 @@ public override Ast Copy() return new PipelineAst(this.Extent, newPipelineElements); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return PipelineElements.Last().GetInferredType(context); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -5670,11 +5445,6 @@ public override Ast Copy() return new CommandParameterAst(this.Extent, this.ParameterName, newArgument, this.ErrorPosition); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -5826,154 +5596,6 @@ public override Ast Copy() }; } - internal override IEnumerable GetInferredType(CompletionContext context) - { - PseudoBindingInfo pseudoBinding = new PseudoParameterBinder() - .DoPseudoParameterBinding(this, null, null, PseudoParameterBinder.BindingType.ParameterCompletion); - if (pseudoBinding == null || pseudoBinding.CommandInfo == null) - { - yield break; - } - - AstParameterArgumentPair pathArgument; - string pathParameterName = "Path"; - if (!pseudoBinding.BoundArguments.TryGetValue(pathParameterName, out pathArgument)) - { - pathParameterName = "LiteralPath"; - pseudoBinding.BoundArguments.TryGetValue(pathParameterName, out pathArgument); - } - - var commandInfo = pseudoBinding.CommandInfo; - var pathArgumentPair = pathArgument as AstPair; - if (pathArgumentPair != null && pathArgumentPair.Argument is StringConstantExpressionAst) - { - var pathValue = ((StringConstantExpressionAst)pathArgumentPair.Argument).Value; - try - { - commandInfo = commandInfo.CreateGetCommandCopy(new[] { "-" + pathParameterName, pathValue }); - } - catch (InvalidOperationException) { } - } - - var cmdletInfo = commandInfo as CmdletInfo; - if (cmdletInfo != null) - { - // Special cases - - // new-object - yields an instance of whatever -Type is bound to - if (cmdletInfo.ImplementingType.FullName.Equals("Microsoft.PowerShell.Commands.NewObjectCommand", StringComparison.Ordinal)) - { - AstParameterArgumentPair typeArgument; - if (pseudoBinding.BoundArguments.TryGetValue("TypeName", out typeArgument)) - { - var typeArgumentPair = typeArgument as AstPair; - if (typeArgumentPair != null && typeArgumentPair.Argument is StringConstantExpressionAst) - { - yield return new PSTypeName(((StringConstantExpressionAst)typeArgumentPair.Argument).Value); - } - } - yield break; - } - - // Get-CimInstance/New-CimInstance - yields a CimInstance with ETS type based on its arguments for -Namespace and -ClassName parameters - if (cmdletInfo.ImplementingType.FullName.Equals("Microsoft.Management.Infrastructure.CimCmdlets.GetCimInstanceCommand", StringComparison.Ordinal) || - cmdletInfo.ImplementingType.FullName.Equals("Microsoft.Management.Infrastructure.CimCmdlets.NewCimInstanceCommand", StringComparison.Ordinal)) - { - string pseudoboundNamespace = CompletionCompleters.NativeCommandArgumentCompletion_ExtractSecondaryArgument(pseudoBinding.BoundArguments, "Namespace").FirstOrDefault(); - string pseudoboundClassName = CompletionCompleters.NativeCommandArgumentCompletion_ExtractSecondaryArgument(pseudoBinding.BoundArguments, "ClassName").FirstOrDefault(); - if (!string.IsNullOrWhiteSpace(pseudoboundClassName)) - { - yield return new PSTypeName(string.Format( - CultureInfo.InvariantCulture, - "{0}#{1}/{2}", - typeof(Microsoft.Management.Infrastructure.CimInstance).FullName, - pseudoboundNamespace ?? "root/cimv2", - pseudoboundClassName)); - } - yield return new PSTypeName(typeof(Microsoft.Management.Infrastructure.CimInstance)); - yield break; - } - - // where-object - yields whatever we saw before where-object in the pipeline. - // same for sort-object - if (cmdletInfo.ImplementingType == typeof(WhereObjectCommand) - || cmdletInfo.ImplementingType.FullName.Equals("Microsoft.PowerShell.Commands.SortObjectCommand", StringComparison.Ordinal)) - { - var parentPipeline = this.Parent as PipelineAst; - if (parentPipeline != null) - { - int i; - for (i = 0; i < parentPipeline.PipelineElements.Count; i++) - { - if (parentPipeline.PipelineElements[i] == this) - break; - } - if (i > 0) - { - foreach (var typename in parentPipeline.PipelineElements[i - 1].GetInferredType(context)) - { - yield return typename; - } - } - } - - // We could also check -InputObject, but that is rarely used. But don't bother continuing. - yield break; - } - - // foreach-object - yields the type of it's script block parameters - if (cmdletInfo.ImplementingType == typeof(ForEachObjectCommand)) - { - AstParameterArgumentPair argument; - if (pseudoBinding.BoundArguments.TryGetValue("Begin", out argument)) - { - foreach (var type in GetInferredTypeFromScriptBlockParameter(argument, context)) - { - yield return type; - } - } - - if (pseudoBinding.BoundArguments.TryGetValue("Process", out argument)) - { - foreach (var type in GetInferredTypeFromScriptBlockParameter(argument, context)) - { - yield return type; - } - } - - if (pseudoBinding.BoundArguments.TryGetValue("End", out argument)) - { - foreach (var type in GetInferredTypeFromScriptBlockParameter(argument, context)) - { - yield return type; - } - } - } - } - - // The OutputType property ignores the parameter set specified in the OutputTypeAttribute. - // With psuedo-binding, we actually know the candidate parameter sets, so we could take - // advantage of it here, but I opted for the simpler code because so few cmdlets use - // ParameterSetName in OutputType and of the ones I know about, it isn't that useful. - foreach (var result in commandInfo.OutputType) - { - yield return result; - } - } - - private IEnumerable GetInferredTypeFromScriptBlockParameter(AstParameterArgumentPair argument, CompletionContext context) - { - var argumentPair = argument as AstPair; - if (argumentPair != null && argumentPair.Argument is ScriptBlockExpressionAst) - { - var scriptBlockExpressionAst = (ScriptBlockExpressionAst)argumentPair.Argument; - foreach (var type in scriptBlockExpressionAst.ScriptBlock.GetInferredType(context)) - { - yield return type; - } - } - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -6055,11 +5677,6 @@ public override Ast Copy() return new CommandExpressionAst(this.Extent, newExpression, newRedirections); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Expression.GetInferredType(context); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -6111,11 +5728,6 @@ protected RedirectionAst(IScriptExtent extent, RedirectionStream from) /// The stream to read objects from. Objects are either merged with another stream, or written to a file. /// public RedirectionStream FromStream { get; private set; } - - internal override IEnumerable GetInferredType(CompletionContext context) - { - return EmptyPSTypeNameArray; - } } /// @@ -6368,11 +5980,6 @@ public override Ast Copy() return new AssignmentStatementAst(this.Extent, newLeft, this.Operator, newRight, this.ErrorPosition); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Left.GetInferredType(context); - } - /// /// Return all of the expressions assigned by the assignment statement. Typically /// it's just a variable expression, but if is an , @@ -6507,17 +6114,6 @@ public override Ast Copy() }; } - - /// - /// Gets inferred type - /// - /// - /// - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Body.GetInferredType(context); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -7006,17 +6602,6 @@ internal override AstVisitAction InternalVisit(AstVisitor visitor) #endregion Visitors - /// - /// Get inferred type of DynamicKeywordStatementAst - /// - /// - /// - internal override IEnumerable GetInferredType(CompletionContext context) - { - // TODO: What is the right InferredType for the AST - return CommandElements[0].GetInferredType(context); - } - #region Internal Properties/Methods internal DynamicKeyword Keyword @@ -7417,20 +7002,6 @@ public override Type StaticType } internal static readonly PSTypeName[] BoolTypeNameArray = new PSTypeName[] { new PSTypeName(typeof(bool)) }; - internal override IEnumerable GetInferredType(CompletionContext context) - { - switch (Operator) - { - case TokenKind.Xor: - case TokenKind.And: - case TokenKind.Or: - case TokenKind.Is: - return BoolTypeNameArray; - } - - // This may not be right, but we're guessing anyway, and it's not a bad guess. - return Left.GetInferredType(context); - } #region Visitors @@ -7521,13 +7092,6 @@ public override Type StaticType } } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return (TokenKind == TokenKind.Not || TokenKind == TokenKind.Exclaim) - ? BinaryExpressionAst.BoolTypeNameArray - : Child.GetInferredType(context); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -7597,11 +7161,6 @@ public override Ast Copy() return new BlockStatementAst(this.Extent, this.Kind, newBody); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Body.GetInferredType(context); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -7673,11 +7232,6 @@ public override Ast Copy() return new AttributedExpressionAst(this.Extent, newAttribute, newChild); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Child.GetInferredType(context); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -7801,19 +7355,6 @@ public override Type StaticType get { return this.Type.TypeName.GetReflectionType() ?? typeof(object); } } - internal override IEnumerable GetInferredType(CompletionContext context) - { - var type = this.Type.TypeName.GetReflectionType(); - if (type != null) - { - yield return new PSTypeName(type); - } - else - { - yield return new PSTypeName(this.Type.TypeName.FullName); - } - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -7919,208 +7460,6 @@ public override Ast Copy() return new MemberExpressionAst(this.Extent, newExpression, newMember, this.Static); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - // If the member name isn't simple, don't even try. - var memberAsStringConst = Member as StringConstantExpressionAst; - if (memberAsStringConst == null) - yield break; - - PSTypeName[] exprType; - if (this.Static) - { - var exprAsType = Expression as TypeExpressionAst; - if (exprAsType == null) - yield break; - var type = exprAsType.TypeName.GetReflectionType(); - if (type == null) - { - var typeName = exprAsType.TypeName as TypeName; - if (typeName == null || typeName._typeDefinitionAst == null) - yield break; - - exprType = new[] { new PSTypeName(typeName._typeDefinitionAst) }; - } - else - { - exprType = new[] { new PSTypeName(type) }; - } - } - else - { - exprType = Expression.GetInferredType(context).ToArray(); - if (exprType.Length == 0) - yield break; - } - - var maybeWantDefaultCtor = this.Static - && this is InvokeMemberExpressionAst - && memberAsStringConst.Value.Equals("new", StringComparison.OrdinalIgnoreCase); - - // We use a list of member names because we might discover aliases properties - // and if we do, we'll add to the list. - var memberNameList = new List { memberAsStringConst.Value }; - foreach (var type in exprType) - { - var members = CompletionCompleters.GetMembersByInferredType(type, context, this.Static, filter: null); - - for (int i = 0; i < memberNameList.Count; i++) - { - string memberName = memberNameList[i]; - foreach (var member in members) - { - var propertyInfo = member as PropertyInfo; - if (propertyInfo != null) - { - if (propertyInfo.Name.Equals(memberName, StringComparison.OrdinalIgnoreCase) && - !(this is InvokeMemberExpressionAst)) - { - yield return new PSTypeName(propertyInfo.PropertyType); - break; - } - continue; - } - - var fieldInfo = member as FieldInfo; - if (fieldInfo != null) - { - if (fieldInfo.Name.Equals(memberName, StringComparison.OrdinalIgnoreCase) && - !(this is InvokeMemberExpressionAst)) - { - yield return new PSTypeName(fieldInfo.FieldType); - break; - } - continue; - } - - var methodCacheEntry = member as DotNetAdapter.MethodCacheEntry; - if (methodCacheEntry != null) - { - if (methodCacheEntry[0].method.Name.Equals(memberName, StringComparison.OrdinalIgnoreCase)) - { - maybeWantDefaultCtor = false; - if (this is InvokeMemberExpressionAst) - { - foreach (var method in methodCacheEntry.methodInformationStructures) - { - var methodInfo = method.method as MethodInfo; - if (methodInfo != null && !methodInfo.ReturnType.GetTypeInfo().ContainsGenericParameters) - { - yield return new PSTypeName(methodInfo.ReturnType); - } - } - } - else - { - // Accessing a method as a property, we'd return a wrapper over the method. - yield return new PSTypeName(typeof(PSMethod)); - } - break; - } - continue; - } - - var memberAst = member as MemberAst; - if (memberAst != null) - { - if (memberAst.Name.Equals(memberName, StringComparison.OrdinalIgnoreCase)) - { - if (this is InvokeMemberExpressionAst) - { - var functionMemberAst = memberAst as FunctionMemberAst; - if (functionMemberAst != null && !functionMemberAst.IsReturnTypeVoid()) - { - yield return new PSTypeName(functionMemberAst.ReturnType.TypeName); - } - } - else - { - var propertyMemberAst = memberAst as PropertyMemberAst; - if (propertyMemberAst != null) - { - if (propertyMemberAst.PropertyType != null) - { - yield return new PSTypeName(propertyMemberAst.PropertyType.TypeName); - } - else - { - yield return new PSTypeName(typeof(object)); - } - } - else - { - // Accessing a method as a property, we'd return a wrapper over the method. - yield return new PSTypeName(typeof(PSMethod)); - } - } - } - continue; - } - - var memberInfo = member as PSMemberInfo; - if (memberInfo == null || - !memberInfo.Name.Equals(memberName, StringComparison.OrdinalIgnoreCase)) - { - continue; - } - - var noteProperty = member as PSNoteProperty; - if (noteProperty != null) - { - yield return new PSTypeName(noteProperty.Value.GetType()); - break; - } - - var aliasProperty = member as PSAliasProperty; - if (aliasProperty != null) - { - memberNameList.Add(aliasProperty.ReferencedMemberName); - break; - } - - var codeProperty = member as PSCodeProperty; - if (codeProperty != null) - { - if (codeProperty.GetterCodeReference != null) - { - yield return new PSTypeName(codeProperty.GetterCodeReference.ReturnType); - } - break; - } - - ScriptBlock scriptBlock = null; - var scriptProperty = member as PSScriptProperty; - if (scriptProperty != null) - { - scriptBlock = scriptProperty.GetterScript; - } - - var scriptMethod = member as PSScriptMethod; - if (scriptMethod != null) - { - scriptBlock = scriptMethod.Script; - } - - if (scriptBlock != null) - { - foreach (var t in scriptBlock.OutputType) - { - yield return t; - } - } - - break; - } - } - - // We didn't find any constructors but they used [T]::new() syntax - if (maybeWantDefaultCtor) - { - yield return type; - } - } - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -9239,11 +8578,6 @@ public override Ast Copy() /// public override Type StaticType { get { return typeof(Type); } } - internal override IEnumerable GetInferredType(CompletionContext context) - { - yield return new PSTypeName(StaticType); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -9358,205 +8692,6 @@ public override Ast Copy() return new VariableExpressionAst(this.Extent, this.VariablePath, this.Splatted); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - // We don't need to handle drive qualified variables, we can usually get those values - // without needing to "guess" at the type. - if (!VariablePath.IsVariable) - { - // Not a variable - the caller should have already tried going to session state - // to get the item and hence it's type, but that must have failed. Don't try again. - yield break; - } - - Ast parent = this.Parent; - if (VariablePath.IsUnqualified && - (SpecialVariables.IsUnderbar(VariablePath.UserPath) - || VariablePath.UserPath.Equals(SpecialVariables.PSItem, StringComparison.OrdinalIgnoreCase))) - { - // $_ is special, see if we're used in a script block in some pipeline. - while (parent != null) - { - if (parent is ScriptBlockExpressionAst) - break; - parent = parent.Parent; - } - - if (parent != null) - { - if (parent.Parent is CommandExpressionAst && parent.Parent.Parent is PipelineAst) - { - // Script block in a hash table, could be something like: - // dir | ft @{ Expression = { $_ } } - if (parent.Parent.Parent.Parent is HashtableAst) - { - parent = parent.Parent.Parent.Parent; - } - else if (parent.Parent.Parent.Parent is ArrayLiteralAst && parent.Parent.Parent.Parent.Parent is HashtableAst) - { - parent = parent.Parent.Parent.Parent.Parent; - } - } - if (parent.Parent is CommandParameterAst) - { - parent = parent.Parent; - } - - var commandAst = parent.Parent as CommandAst; - if (commandAst != null) - { - // We found a command, see if there is a previous command in the pipeline. - PipelineAst pipelineAst = (PipelineAst)commandAst.Parent; - var previousCommandIndex = pipelineAst.PipelineElements.IndexOf(commandAst) - 1; - if (previousCommandIndex >= 0) - { - foreach (var result in pipelineAst.PipelineElements[0].GetInferredType(context)) - { - if (result.Type != null) - { - // Assume (because we're looking at $_ and we're inside a script block that is an - // argument to some command) that the type we're getting is actually unrolled. - // This might not be right in all cases, but with our simple analysis, it's - // right more often than it's wrong. - if (result.Type.IsArray) - { - yield return new PSTypeName(result.Type.GetElementType()); - continue; - } - - if (typeof(IEnumerable).IsAssignableFrom(result.Type)) - { - // We can't deduce much from IEnumerable, but we can if it's generic. - var enumerableInterfaces = result.Type.GetInterfaces().Where( - t => - t.GetTypeInfo().IsGenericType && - t.GetGenericTypeDefinition() == typeof(IEnumerable<>)); - foreach (var i in enumerableInterfaces) - { - yield return new PSTypeName(i.GetGenericArguments()[0]); - } - continue; - } - } - yield return result; - } - } - yield break; - } - } - } - - // For certain variables, we always know their type, well at least we can assume we know. - if (VariablePath.IsUnqualified) - { - if (VariablePath.UserPath.Equals(SpecialVariables.This, StringComparison.OrdinalIgnoreCase) && - context.CurrentTypeDefinitionAst != null) - { - yield return new PSTypeName(context.CurrentTypeDefinitionAst); - yield break; - } - - for (int i = 0; i < SpecialVariables.AutomaticVariables.Length; i++) - { - if (VariablePath.UserPath.Equals(SpecialVariables.AutomaticVariables[i], StringComparison.OrdinalIgnoreCase)) - { - var type = SpecialVariables.AutomaticVariableTypes[i]; - if (!(type == typeof(object))) - yield return new PSTypeName(type); - break; - } - } - } - - // Look for our variable as a parameter or on the lhs of an assignment - hopefully we'll find either - // a type constraint or at least we can use the rhs to infer the type. - - while (parent.Parent != null) - { - parent = parent.Parent; - } - - if (parent.Parent is FunctionDefinitionAst) - { - parent = parent.Parent; - } - - int startOffset = this.Extent.StartOffset; - var targetAsts = AstSearcher.FindAll(parent, - ast => (ast is ParameterAst || ast is AssignmentStatementAst || ast is ForEachStatementAst || ast is CommandAst) - && AstAssignsToSameVariable(ast) - && ast.Extent.EndOffset < startOffset, - searchNestedScriptBlocks: true); - - var parameterAst = targetAsts.OfType().FirstOrDefault(); - if (parameterAst != null) - { - var parameterTypes = parameterAst.GetInferredType(context).ToArray(); - if (parameterTypes.Length > 0) - { - foreach (var parameterType in parameterTypes) - { - yield return parameterType; - } - yield break; - } - } - - var assignAsts = targetAsts.OfType().ToArray(); - - // If any of the assignments lhs use a type constraint, then we use that. - // Otherwise, we use the rhs of the "nearest" assignment - foreach (var assignAst in assignAsts) - { - var lhsConvert = assignAst.Left as ConvertExpressionAst; - if (lhsConvert != null) - { - yield return new PSTypeName(lhsConvert.Type.TypeName); - yield break; - } - } - - var foreachAst = targetAsts.OfType().FirstOrDefault(); - if (foreachAst != null) - { - foreach (var typeName in foreachAst.Condition.GetInferredType(context)) - { - yield return typeName; - } - yield break; - } - - var commandCompletionAst = targetAsts.OfType().FirstOrDefault(); - if (commandCompletionAst != null) - { - foreach (var typeName in commandCompletionAst.GetInferredType(context)) - { - yield return typeName; - } - yield break; - } - - int smallestDiff = int.MaxValue; - AssignmentStatementAst closestAssignment = null; - foreach (var assignAst in assignAsts) - { - var endOffset = assignAst.Extent.EndOffset; - if ((startOffset - endOffset) < smallestDiff) - { - smallestDiff = startOffset - endOffset; - closestAssignment = assignAst; - } - } - - if (closestAssignment != null) - { - foreach (var type in closestAssignment.Right.GetInferredType(context)) - { - yield return type; - } - } - } - internal bool IsSafeVariableReference(HashSet validVariables, ref bool usesParameter) { bool ok = false; @@ -9580,71 +8715,6 @@ internal bool IsSafeVariableReference(HashSet validVariables, ref bool u return ok; } - private bool AstAssignsToSameVariable(Ast ast) - { - var parameterAst = ast as ParameterAst; - if (parameterAst != null) - { - return VariablePath.IsUnscopedVariable && - parameterAst.Name.VariablePath.UnqualifiedPath.Equals(VariablePath.UnqualifiedPath, StringComparison.OrdinalIgnoreCase); - } - - var foreachAst = ast as ForEachStatementAst; - if (foreachAst != null) - { - return VariablePath.IsUnscopedVariable && - foreachAst.Variable.VariablePath.UnqualifiedPath.Equals(VariablePath.UnqualifiedPath, StringComparison.OrdinalIgnoreCase); - } - - var commandAst = ast as CommandAst; - if (commandAst != null) - { - string[] variableParameters = new string[] { "PV", "PipelineVariable", "OV", "OutVariable" }; - StaticBindingResult bindingResult = StaticParameterBinder.BindCommand(commandAst, false, variableParameters); - - if (bindingResult != null) - { - ParameterBindingResult parameterBindingResult; - - foreach (string commandVariableParameter in variableParameters) - { - if (bindingResult.BoundParameters.TryGetValue(commandVariableParameter, out parameterBindingResult)) - { - if (String.Equals(VariablePath.UnqualifiedPath, (string)parameterBindingResult.ConstantValue, StringComparison.OrdinalIgnoreCase)) - { - return true; - } - } - } - } - - return false; - } - - var assignmentAst = (AssignmentStatementAst)ast; - var lhs = assignmentAst.Left; - var convertExpr = lhs as ConvertExpressionAst; - if (convertExpr != null) - { - lhs = convertExpr.Child; - } - - var varExpr = lhs as VariableExpressionAst; - if (varExpr == null) - return false; - - var candidateVarPath = varExpr.VariablePath; - if (candidateVarPath.UserPath.Equals(VariablePath.UserPath, StringComparison.OrdinalIgnoreCase)) - return true; - - // The following condition is making an assumption that at script scope, we didn't use $script:, but in the local scope, we did - // If we are searching anything other than script scope, this is wrong. - if (VariablePath.IsScript && VariablePath.UnqualifiedPath.Equals(candidateVarPath.UnqualifiedPath, StringComparison.OrdinalIgnoreCase)) - return true; - - return false; - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -9794,14 +8864,6 @@ public override Type StaticType get { return Value != null ? Value.GetType() : typeof(object); } } - internal override IEnumerable GetInferredType(CompletionContext context) - { - if (Value != null) - { - yield return new PSTypeName(Value.GetType()); - } - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -10075,11 +9137,6 @@ public override Type StaticType /// internal string FormatExpression { get; private set; } - internal override IEnumerable GetInferredType(CompletionContext context) - { - yield return new PSTypeName(typeof(string)); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -10154,11 +9211,6 @@ public override Type StaticType get { return typeof(ScriptBlock); } } - internal override IEnumerable GetInferredType(CompletionContext context) - { - yield return new PSTypeName(typeof(ScriptBlock)); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -10228,11 +9280,6 @@ public override Ast Copy() /// public override Type StaticType { get { return typeof(object[]); } } - internal override IEnumerable GetInferredType(CompletionContext context) - { - yield return new PSTypeName(typeof(object[])); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -10327,11 +9374,6 @@ public override Ast Copy() /// public override Type StaticType { get { return typeof(Hashtable); } } - internal override IEnumerable GetInferredType(CompletionContext context) - { - yield return new PSTypeName(typeof(Hashtable)); - } - // Indicates that this ast was constructed as part of a schematized object instead of just a plain hash literal. internal bool IsSchemaElement { get; set; } @@ -10411,11 +9453,6 @@ public override Ast Copy() /// public override Type StaticType { get { return typeof(object[]); } } - internal override IEnumerable GetInferredType(CompletionContext context) - { - yield return new PSTypeName(typeof(object[])); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -10477,11 +9514,6 @@ public override Ast Copy() return new ParenExpressionAst(this.Extent, newPipeline); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return Pipeline.GetInferredType(context); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -10548,11 +9580,6 @@ public override Ast Copy() return new SubExpressionAst(this.Extent, newStatementBlock); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return SubExpression.GetInferredType(context); - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) @@ -10623,11 +9650,6 @@ public override Ast Copy() return newUsingExpression; } - internal override IEnumerable GetInferredType(CompletionContext context) - { - return SubExpression.GetInferredType(context); - } - #region UsingExpression Utilities internal const string UsingPrefix = "__using_"; @@ -10771,53 +9793,6 @@ public override Ast Copy() return new IndexExpressionAst(this.Extent, newTarget, newIndex); } - internal override IEnumerable GetInferredType(CompletionContext context) - { - var targetTypes = Target.GetInferredType(context); - foreach (var psType in targetTypes) - { - var type = psType.Type; - if (type != null) - { - if (type.IsArray) - { - yield return new PSTypeName(type.GetElementType()); - continue; - } - - foreach (var i in type.GetInterfaces()) - { - if (i.GetTypeInfo().IsGenericType && i.GetGenericTypeDefinition() == typeof(IDictionary<,>)) - { - var valueType = i.GetGenericArguments()[1]; - if (!valueType.GetTypeInfo().ContainsGenericParameters) - { - yield return new PSTypeName(valueType); - } - } - } - - var defaultMember = type.GetCustomAttributes(true).FirstOrDefault(); - if (defaultMember != null) - { - var indexers = - type.GetMethods(BindingFlags.Public | BindingFlags.Instance).Where( - m => m.Name.Equals("get_" + defaultMember.MemberName)); - foreach (var indexer in indexers) - { - yield return new PSTypeName(indexer.ReturnType); - } - } - } - - // Inferred type of target wasn't indexable. Assume (perhaps incorrectly) - // that it came from OutputType and that more than one object was returned - // and that we're indexing because of that, in which case, OutputType really - // is the inferred type. - yield return psType; - } - } - #region Visitors internal override object Accept(ICustomAstVisitor visitor) diff --git a/src/System.Management.Automation/engine/CommandCompletion/CompletionExecutionHelper.cs b/src/System.Management.Automation/utils/PowerShellExecutionHelper.cs similarity index 74% rename from src/System.Management.Automation/engine/CommandCompletion/CompletionExecutionHelper.cs rename to src/System.Management.Automation/utils/PowerShellExecutionHelper.cs index 90ef923e152..93ca34b8b71 100644 --- a/src/System.Management.Automation/engine/CommandCompletion/CompletionExecutionHelper.cs +++ b/src/System.Management.Automation/utils/PowerShellExecutionHelper.cs @@ -1,34 +1,28 @@ - -/********************************************************************++ +/********************************************************************++ Copyright (c) Microsoft Corporation. All rights reserved. --********************************************************************/ +using System.Collections; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Management.Automation.Runspaces; + namespace System.Management.Automation { - using System; - using System.Collections.ObjectModel; - using System.Collections.Generic; - using System.Management.Automation.Runspaces; - using System.Collections; - - /// - /// Auxiliary class to the execution of commands as needed by - /// CommandCompletion - /// - internal class CompletionExecutionHelper + internal class PowerShellExecutionHelper { #region Constructors - // Creates a new CompletionExecutionHelper with the PowerShell instance that will be used to execute the tab expansion commands + // Creates a new PowerShellExecutionHelper with the PowerShell instance that will be used to execute the tab expansion commands // Used by the ISE - internal CompletionExecutionHelper(PowerShell powershell) + internal PowerShellExecutionHelper(PowerShell powershell) { if (powershell == null) { throw PSTraceSource.NewArgumentNullException("powershell"); } - this.CurrentPowerShell = powershell; + CurrentPowerShell = powershell; } #endregion Constructors @@ -44,16 +38,10 @@ internal CompletionExecutionHelper(PowerShell powershell) internal PowerShell CurrentPowerShell { get; set; } // Returns true if this instance is currently executing a command - internal bool IsRunning - { - get { return CurrentPowerShell.InvocationStateInfo.State == PSInvocationState.Running; } - } + internal bool IsRunning => CurrentPowerShell.InvocationStateInfo.State == PSInvocationState.Running; // Returns true if the command executed by this instance was stopped - internal bool IsStopped - { - get { return CurrentPowerShell.InvocationStateInfo.State == PSInvocationState.Stopped; } - } + internal bool IsStopped => CurrentPowerShell.InvocationStateInfo.State == PSInvocationState.Stopped; #endregion Fields and Properties @@ -62,7 +50,7 @@ internal bool IsStopped internal Collection ExecuteCommand(string command) { Exception unused; - return this.ExecuteCommand(command, true, out unused, null); + return ExecuteCommand(command, true, out unused, null); } internal bool ExecuteCommandAndGetResultAsBool() @@ -106,7 +94,7 @@ internal Collection ExecuteCommand(string command, bool isScript, out exceptionThrown = null; // This flag indicates a previous call to this method had its pipeline cancelled - if (this.CancelTabCompletion) + if (CancelTabCompletion) { return new Collection(); } @@ -130,10 +118,10 @@ internal Collection ExecuteCommand(string command, bool isScript, out // If this pipeline has been stopped lets set a flag to cancel all future tab completion calls // untill the next completion - if (this.IsStopped) + if (IsStopped) { results = new Collection(); - this.CancelTabCompletion = true; + CancelTabCompletion = true; } } catch (Exception e) @@ -149,7 +137,7 @@ internal Collection ExecuteCurrentPowerShell(out Exception exceptionTh exceptionThrown = null; // This flag indicates a previous call to this method had its pipeline cancelled - if (this.CancelTabCompletion) + if (CancelTabCompletion) { return new Collection(); } @@ -161,10 +149,10 @@ internal Collection ExecuteCurrentPowerShell(out Exception exceptionTh // If this pipeline has been stopped lets set a flag to cancel all future tab completion calls // untill the next completion - if (this.IsStopped) + if (IsStopped) { results = new Collection(); - this.CancelTabCompletion = true; + CancelTabCompletion = true; } } catch (Exception e) @@ -236,4 +224,39 @@ internal static void SafeAddToStringList(List list, object obj) #endregion Helpers } -} \ No newline at end of file + + internal static class PowerShellExtensionHelpers + { + internal static PowerShell AddCommandWithPreferenceSetting(this PowerShellExecutionHelper helper, + string command, Type type = null) + { + return helper.CurrentPowerShell.AddCommandWithPreferenceSetting(command, type); + } + + internal static PowerShell AddCommandWithPreferenceSetting(this PowerShell powershell, string command, Type type = null) + { + Diagnostics.Assert(powershell != null, "the passed-in powershell cannot be null"); + Diagnostics.Assert(!String.IsNullOrWhiteSpace(command), + "the passed-in command name should not be null or whitespaces"); + + if (type != null) + { + var cmdletInfo = new CmdletInfo(command, type); + + powershell.AddCommand(cmdletInfo); + } + else + { + powershell.AddCommand(command); + } + powershell + .AddParameter("ErrorAction", ActionPreference.Ignore) + .AddParameter("WarningAction", ActionPreference.Ignore) + .AddParameter("InformationAction", ActionPreference.Ignore) + .AddParameter("Verbose", false) + .AddParameter("Debug", false); + + return powershell; + } + } +} diff --git a/test/powershell/engine/Api/TypeInference.Tests.ps1 b/test/powershell/engine/Api/TypeInference.Tests.ps1 new file mode 100644 index 00000000000..65e9834ae5c --- /dev/null +++ b/test/powershell/engine/Api/TypeInference.Tests.ps1 @@ -0,0 +1,859 @@ +using namespace System.Management.Automation +using namespace System.Collections.Generic + +Describe "Type inference Tests" -tags "CI" { + BeforeAll { + $ati = [Cmdlet].Assembly.GetType("System.Management.Automation.AstTypeInference") + $inferType = $ati.GetMethods().Where{$_.Name -ceq "InferTypeOf"} + $m1 = 'System.Collections.Generic.IList`1[System.Management.Automation.PSTypeName] InferTypeOf(System.Management.Automation.Language.Ast)' + $m2 = 'System.Collections.Generic.IList`1[System.Management.Automation.PSTypeName] InferTypeOf(System.Management.Automation.Language.Ast, System.Management.Automation.TypeInferenceRuntimePermissions)' + $m3 = 'System.Collections.Generic.IList`1[System.Management.Automation.PSTypeName] InferTypeOf(System.Management.Automation.Language.Ast, System.Management.Automation.PowerShell)' + $m4 = 'System.Collections.Generic.IList`1[System.Management.Automation.PSTypeName] InferTypeOf(System.Management.Automation.Language.Ast, System.Management.Automation.PowerShell, System.Management.Automation.TypeInferenceRuntimePermissions)' + + $inferTypeOf1 = $inferType.Where{$m1 -eq $_}[0] + $inferTypeOf2 = $inferType.Where{$m2 -eq $_}[0] + $inferTypeOf3 = $inferType.Where{$m3 -eq $_}[0] + $inferTypeOf4 = $inferType.Where{$m4 -eq $_}[0] + + class AstTypeInference { + static [IList[PSTypeName]] InferTypeOf([Language.Ast] $ast) { + return $script:inferTypeOf1.Invoke($null, $ast) + } + + static [IList[PSTypeName]] InferTypeOf([Language.Ast] $ast, [System.Management.Automation.TypeInferenceRuntimePermissions] $runtimePermissions) { + return $script:inferTypeOf2.Invoke($null, @($ast, $runtimePermissions)) + } + + static [IList[PSTypeName]] InferTypeOf([Language.Ast] $ast, [System.Management.Automation.PowerShell] $powershell) { + return $script:inferTypeOf3.Invoke($null, @($ast, $powershell)) + } + + static [IList[PSTypeName]] InferTypeOf([Language.Ast] $ast, [PowerShell] $powerShell, [System.Management.Automation.TypeInferenceRuntimePermissions] $runtimePermissions) { + return $script:inferTypeOf4.Invoke($null, @($ast, $powerShell, $runtimePermissions)) + } + } + + } + + It "Infers type from integer" { + $res = [AstTypeInference]::InferTypeOf( { 1 }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Int32' + } + + + It "Infers type from string literal" { + $res = [AstTypeInference]::InferTypeOf( { "Text" }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.String' + } + + It "Infers type from type expression" { + $res = [AstTypeInference]::InferTypeOf( { [int] }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Type' + } + + It "Infers type from hashtable" { + $res = [AstTypeInference]::InferTypeOf( { @{} }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Collections.Hashtable' + } + + It "Infers type from array expression" { + $res = [AstTypeInference]::InferTypeOf( { @() }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.object[]' + } + + It "Infers type from Array literal" { + $res = [AstTypeInference]::InferTypeOf( { , 1 }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.object[]' + } + + It "Infers type from array IndexExpresssion" { + $res = [AstTypeInference]::InferTypeOf( { (1, 2, 3)[0] }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.object' + } + + It "Infers type from generic container IndexExpression" { + $res = [AstTypeInference]::InferTypeOf( { + [System.Collections.Generic.List[int]]::new()[0] + }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Int32' + } + + It 'Infers type of Index expression on Dictionary' { + $ast = { + [System.Collections.Generic.Dictionary[int, DateTime]]::new()[1] + }.ast.EndBlock.Statements[0].PipelineElements[0].Expression + $res = [AstTypeInference]::InferTypeOf( $ast ) + + $res.Count | Should be 1 + $res.Name | Should be System.DateTime + } + + It "Infers type from ScriptblockExpresssion" { + $res = [AstTypeInference]::InferTypeOf( { {} }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Management.Automation.Scriptblock' + } + + It "Infers type from paren expression" { + $res = [AstTypeInference]::InferTypeOf( { (1) }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Int32' + } + + It "Infers type from expandable string expression" { + $res = [AstTypeInference]::InferTypeOf( { "$(1)" }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.String' + } + + It "Infers type from cast expression" { + $res = [AstTypeInference]::InferTypeOf( { [int] '1'}.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Int32' + } + + It "Infers type from using namespace" { + $errors = $null + $tokens = $null + $ast = [Language.Parser]::ParseInput("using namespace System", [ref] $tokens, [ref] $errors) + $res = [AstTypeInference]::InferTypeOf( $ast.Find( {param($a) $a -is [System.Management.Automation.Language.UsingStatementAst] }, $true)) + $res.Count | Should Be 0 + } + + It "Infers type from unary expression" { + $res = [AstTypeInference]::InferTypeOf( { !$true }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Boolean' + } + + It "Infers type from param block" { + $res = [AstTypeInference]::InferTypeOf( { param() }.Ast) + $res.Count | Should Be 0 + } + + + It "Infers type from using statement" { + $res = [AstTypeInference]::InferTypeOf( { $pid = 1; $using:pid }.Ast.EndBlock.Statements[1].PipelineElements[0].Expression) + $res.Count | Should Be 1 + $res.Name | Should Be System.Int32 + } + + It "Infers type from param block" { + $res = [AstTypeInference]::InferTypeOf( { param([int] $i)}.Ast.ParamBlock) + $res.Count | Should Be 0 + } + + It "Infers type no type from Attribute" { + $res = [AstTypeInference]::InferTypeOf( { + [OutputType([int])] + param( + )}.Ast.ParamBlock.Attributes[0]) + $res.Count | Should Be 0 + } + + It "Infers type no type from named Attribute argument" { + $res = [AstTypeInference]::InferTypeOf( { + [OutputType(Type = [int])] + param( + )}.Ast.ParamBlock.Attributes[0].NamedArguments[0]) + $res.Count | Should Be 0 + } + + It "Infers type parameter types" { + $res = [AstTypeInference]::InferTypeOf( { + param([int] $i, [string] $s) + }.Ast.ParamBlock.Parameters[0]) + $res.Count | Should Be 1 + $res.Name | Should be System.Int32 + } + + It "Infers type parameter from PSTypeNameAttribute type" -Skip:(!$IsWindows) { + $res = [AstTypeInference]::InferTypeOf( { + param([int] $i, [PSTypeName('System.Management.ManagementObject#root\cimv2\Win32_Process')] $s) + }.Ast.ParamBlock.Parameters[1]) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Management.ManagementObject#root\cimv2\Win32_Process' + } + + It "Infers type from DATA statement" { + $res = [AstTypeInference]::InferTypeOf( { + DATA { + "text" + } + }.Ast.EndBlock) + $res.Count | Should Be 1 + $res.Name | Should be 'System.String' + } + + + It "Infers type from named block" { + $res = [AstTypeInference]::InferTypeOf( { begin {1}}.Ast.BeginBlock) + $res.Count | Should Be 1 + $res.Name | Should Be System.Int32 + } + + It "Infers type from function definition" { + $res = [AstTypeInference]::InferTypeOf( { + function foo { + return 1 + } + }.Ast.EndBlock) + $res.Count | Should Be 0 + } + + It "Infers type from convert expression" { + $errors = $null + $tokens = $null + $ast = [Language.Parser]::ParseInput('[int] "4"', [ref] $tokens, [ref] $errors) + $res = [AstTypeInference]::InferTypeOf( $ast.EndBlock.Statements[0]) + $res.Count | Should Be 1 + $res.Name | Should Be 'System.Int32' + } + + It "Infers type from type constraint" { + $errors = $null + $tokens = $null + $ast = [Language.Parser]::ParseInput('[int] $i', [ref] $tokens, [ref] $errors) + $res = [AstTypeInference]::InferTypeOf( $ast.EndBlock.Statements[0].PipelineElements[0].Expression.Attribute) + $res.Count | Should Be 0 + } + + It "Infers type from instance member property" { + $res = [AstTypeInference]::InferTypeOf( { 'Text'.Length }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Int32' + } + + It "Infers type from static member property" { + $res = [AstTypeInference]::InferTypeOf( { [DateTime]::Now }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.DateTime' + } + + It "Infers type from instance member method" { + $res = [AstTypeInference]::InferTypeOf( { [int[]].GetElementType() }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Type' + } + + It "Infers type from integer * stringliteral" { + $res = [AstTypeInference]::InferTypeOf( { 5 * "5" }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Int32' + } + + It "Infers type from string literal" { + $res = [AstTypeInference]::InferTypeOf( { "Text" }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.String' + } + + It "Infers type from stringliteral * integer" { + $res = [AstTypeInference]::InferTypeOf( { "5" * 2 }.Ast) + $res.Count | Should Be 1 + $res.Name | Should be 'System.String' + } + + It "Infers type from where-object of integer" { + $res = [AstTypeInference]::InferTypeOf( { [int[]] $i = 1..20; $i | Where-Object {$_ -gt 10} }.Ast) + foreach ($r in $res) { + $r.Name -in 'System.Int32', 'System.Int32[]' | Should be $true + } + } + + It "Infers type from foreach-object of integer" { + $res = [AstTypeInference]::InferTypeOf( { [int[]] $i = 1..20; $i | ForEach-Object {$_ * 10} }.Ast) + $res.Count | Should Be 2 + foreach ($r in $res) { + $r.Name -in 'System.Int32', 'System.Int32[]' | Should be $true + } + } + + It "Infers type from generic new" { + $res = [AstTypeInference]::InferTypeOf( { [System.Collections.Generic.List[int]]::new() }.Ast) + $res.Count | Should Be 1 + $res.Name | Should Match 'System.Collections.Generic.List`1\[\[System.Int32.*' + + } + + It "Infers type from cim command" -Skip:(!$IsWindows) { + $res = [AstTypeInference]::InferTypeOf( { Get-CimInstance -Namespace root/CIMV2 -ClassName Win32_Bios }.Ast) + $res.Count | Should Be 2 + + foreach ($r in $res) { + $r.Name -in 'Microsoft.Management.Infrastructure.CimInstance#root/CIMV2/Win32_Bios', + 'Microsoft.Management.Infrastructure.CimInstance' | Should be $true + } + } + + It "Infers type from foreach-object with begin/end" { + $res = [AstTypeInference]::InferTypeOf( { [int[]] $i = 1..20; $i | ForEach-Object -Begin {"Hi"} {$_ * 10} -End {[int]} }.Ast) + $res.Count | Should Be 4 + foreach ($r in $res) { + $r.Name -in 'System.Int32', 'System.Int32[]', 'System.String', 'System.Type' | Should be $true + } + } + + It "Infers type from OutputTypeAttribute" { + $res = [AstTypeInference]::InferTypeOf( { Get-Process -Id 2345 }.Ast) + $gpsOutput = [Microsoft.PowerShell.Commands.GetProcessCommand].GetCustomAttributes([System.Management.Automation.OutputTypeAttribute], $false).Type + $names = $gpsOutput.Name + foreach ($r in $res) { + $r.Name -in $names | Should Be $true + } + } + + It "Infers type from variable with AllowSafeEval" { + function Hide-GetProcess { Get-Process } + $p = Hide-GetProcess + $res = [AstTypeInference]::InferTypeOf( { $p }.Ast, [TypeInferenceRuntimePermissions]::AllowSafeEval) + $res.Name | Should Be 'System.Diagnostics.Process' + } + + It "Infers type from variable with type in scope" { + + $res = [AstTypeInference]::InferTypeOf( { + $p = 1 + $p + }.Ast) + $res.Name | Should Be 'System.Int32' + } + + It "Infers type from block statement" { + $errors = $null + $tokens = $null + $ast = [Language.Parser]::ParseInput("parallel {1}", [ref] $tokens, [ref] $errors) + + $res = [AstTypeInference]::InferTypeOf( $ast.EndBlock.Statements[0]) + $res.Name | Should Be 'System.Int32' + } + + It 'Infers type from attributed expession' { + $res = [AstTypeInference]::InferTypeOf( { + [ValidateRange(1, 2)] + [int]$i = 1 + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + It 'Infers type from if statement' { + $res = [AstTypeInference]::InferTypeOf( { + if ($true) { return 1} + else { return 'Text'} + }.Ast) + + $res.Count | Should be 2 + foreach ($r in $res) { + $r.Name -in 'System.Int32', 'System.String' | Should be $true + } + } + + It 'Infers type from switch statement' { + $res = [AstTypeInference]::InferTypeOf( { + switch (1, 2, 3) { + (1) { return 'Hello'} + (2) {return [int]} + default {return 1} + } + }.Ast) + + $res.Count | Should be 3 + foreach ($r in $res) { + $r.Name -in 'System.Type', 'System.Int32', 'System.String' | Should be $true + } + } + + It 'Infers type from Foreach statement' { + $res = [AstTypeInference]::InferTypeOf( { + foreach ($i in 1, 2, 3) { + if ($i -eq 1) { return 'Hello'} + if ($i -eq 2) {return [int]} + return 1 + } + }.Ast) + + $res.Count | Should be 3 + foreach ($r in $res) { + $r.Name -in 'System.Type', 'System.Int32', 'System.String' | Should be $true + } + } + + It 'Infers type from While statement' { + $res = [AstTypeInference]::InferTypeOf( { + while ($true) { + if ($i -eq 1) { return 'Hello'} + if ($i -eq 2) {return [int]} + return 1 + } + }.Ast) + + $res.Count | Should be 3 + foreach ($r in $res) { + $r.Name -in 'System.Type', 'System.Int32', 'System.String' | Should be $true + } + } + + It 'Infers type from For statement' { + $res = [AstTypeInference]::InferTypeOf( { + for ($i = 0; $i -lt 10; $i++) { + if ($i -eq 1) { return 'Hello'} + if ($i -eq 2) {return [int]} + return 1 + } + }.Ast) + + $res.Count | Should be 3 + foreach ($r in $res) { + $r.Name -in 'System.Type', 'System.Int32', 'System.String' | Should be $true + } + } + + It 'Infers type from Do-While statement' { + $res = [AstTypeInference]::InferTypeOf( { + do { + if ($i -eq 1) { return 'Hello'} + if ($i -eq 2) {return [int]} + return 1 + }while ($true) + }.Ast) + + $res.Count | Should be 3 + foreach ($r in $res) { + $r.Name -in 'System.Type', 'System.Int32', 'System.String' | Should be $true + } + } + + It 'Infers type from Do-Until statement' { + $res = [AstTypeInference]::InferTypeOf( { + do { + if ($i -eq 1) { return 'Hello'} + if ($i -eq 2) {return [int]} + return 1 + } until ($true) + }.Ast) + + $res.Count | Should be 3 + foreach ($r in $res) { + $r.Name -in 'System.Type', 'System.Int32', 'System.String' | Should be $true + } + } + + It 'Infers type from full scriptblock' { + $res = [AstTypeInference]::InferTypeOf( { + begin {1} + process {"text"} + end {[int]} + }.Ast) + + $res.Count | Should be 3 + foreach ($r in $res) { + $r.Name -in 'System.Type', 'System.Int32', 'System.String' | Should be $true + } + } + + It 'Infers type from sub expression' { + $res = [AstTypeInference]::InferTypeOf( { + $(1) + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + It 'Infers type from Throw statement' { + $res = [AstTypeInference]::InferTypeOf( { + throw 'Foo' + }.Ast) + + $res.Count | Should be 0 + } + + It 'Infers type from Return statement' { + $res = [AstTypeInference]::InferTypeOf( { + return 1 + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be 'System.Int32' + } + + It 'Infers type from New-Object statement' { + $res = [AstTypeInference]::InferTypeOf( { + New-Object -TypeName 'System.Diagnostics.Stopwatch' + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be 'System.Diagnostics.Stopwatch' + } + + It 'Infers type from Continue statement' { + $res = [AstTypeInference]::InferTypeOf( { + continue + }.Ast) + + $res.Count | Should be 0 + } + + It 'Infers type from Break statement' { + $res = [AstTypeInference]::InferTypeOf( { + break + }.Ast) + + $res.Count | Should be 0 + } + + It 'Infers type from Merging redirection' { + $errors = $null + $tokens = $null + $ast = [Language.Parser]::ParseInput("p4 resolve ... 2>&1", [ref] $tokens, [ref] $errors) + $res = [AstTypeInference]::InferTypeOf( $ast.EndBlock.Statements[0].PipelineElements[0].Redirections[0] ) + $res.Count | Should be 0 + } + + It 'Infers type from File redirection' { + $errors = $null + $tokens = $null + $ast = [Language.Parser]::ParseInput("p4 resolve ... > foo.txt", [ref] $tokens, [ref] $errors) + $res = [AstTypeInference]::InferTypeOf( $ast.EndBlock.Statements[0].PipelineElements[0].Redirections[0] ) + $res.Count | Should be 0 + } + + + It 'Infers type of alias property' { + class X { + [int] $Length + } + Update-TypeData -Typename X -MemberType AliasProperty -MemberName AliasLength -Value Length -Force + $res = [AstTypeInference]::InferTypeOf( { + [x]::new().AliasLength + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + + It 'Infers type of code property' { + class X { + static [int] CodeProp([psobject] $o) { return 1 } + } + + class Y {} + Update-TypeData -TypeName Y -MemberName CodeProp -MemberType CodeProperty -Value ([X].GetMethod("CodeProp")) -Force + $res = [AstTypeInference]::InferTypeOf( { + [Y]::new().CodeProp + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + It 'Infers type of script property' { + class Y {} + Update-TypeData -TypeName Y -MemberName ScriptProp -MemberType ScriptProperty -Value {1} -Force + $res = [AstTypeInference]::InferTypeOf( { + [Y]::new().ScriptProp + }.Ast) + + $res.Count | Should be 0 + } + + It 'Infers type of script property with outputtype' { + class Y {} + Update-TypeData -TypeName Y -MemberName ScriptProp -MemberType ScriptProperty -Value {[OutputType([int])]param()1} -Force + $res = [AstTypeInference]::InferTypeOf( { + [Y]::new().ScriptProp + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + It 'Infers type of script method with outputtype' { + class Y {} + Update-TypeData -TypeName Y -MemberName MyScriptMethod -MemberType ScriptMethod -Value {[OutputType([int])]param()1} -Force + $res = [AstTypeInference]::InferTypeOf( { + [Y]::new().MyScriptMethod + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + + It 'Infers type of note property' { + + $res = [AstTypeInference]::InferTypeOf( { + [pscustomobject] @{ + A = '' + }.A + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be 'System.Management.Automation.PSObject' + } + + It 'Infers type of try catch finally' { + + $res = [AstTypeInference]::InferTypeOf( { + try { + 1 + } + catch [exception] { + "Text" + } + finally { + [int] + } + }.Ast) + + $res.Count | Should be 3 + foreach ($r in $res) { + $r.Name -in 'System.Int32', 'System.String', 'System.Type' | Should be $true + } + } + + It "Infers type from trap statement" { + $res = [AstTypeInference]::InferTypeOf( { + trap { + "text" + } + }.Ast.EndBlock.Traps[0]) + $res.Count | Should Be 1 + $res.Name | Should be 'System.String' + } + + It "Infers type from exit statement" { + $res = [AstTypeInference]::InferTypeOf( { + exit + }.Ast.EndBlock) + $res.Count | Should Be 0 + } + + It 'Infers type of Where/Sort/Foreach pipeline' { + $res = [AstTypeInference]::InferTypeOf( { + [int[]](1..10) | Sort-Object -Descending | Where-Object {$_ -gt 3} | ForEach-Object {$_.ToString()} + }.Ast) + + $res.Name | Should be System.String + } + + It 'Infers type of Method accessed as Property' { + $res = [AstTypeInference]::InferTypeOf( { + ''.ToString + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be System.Management.Automation.PSMethod + } + + It 'Infers int from List[int] with foreach' { + $res = [AstTypeInference]::InferTypeOf( { + $l = [System.Collections.Generic.List[string]]::new() + $l | ForEach-Object {$_} + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be System.String + } + + It 'Infers class type' { + $res = [AstTypeInference]::InferTypeOf( { + class X { + [int] $I + [bool] Method() { + return $true + } + } + }.Ast) + + $res.Count | Should be 0 + } + + Context "TestDrivePath" { + BeforeAll { + $errors = $null + $tokens = $null + $p = Resolve-path TestDrive:/ + } + It 'Infers type of command parameter' { + $ast = [Language.Parser]::ParseInput("Get-ChildItem -Path $p/foo.txt", [ref] $tokens, [ref] $errors) + $cmdParam = $ast.EndBlock.Statements[0].Pipelineelements[0].CommandElements[1] + $res = [AstTypeInference]::InferTypeOf( $cmdParam ) + + $res.Count | Should be 0 + } + + It 'Infers type of command parameter - second form' { + $ast = [Language.Parser]::ParseInput("Get-ChildItem -LiteralPath $p/foo.txt", [ref] $tokens, [ref] $errors) + $cmdParam = $ast.EndBlock.Statements[0].Pipelineelements[0].CommandElements[1] + $res = [AstTypeInference]::InferTypeOf( $cmdParam ) + $res.Count | Should be 0 + } + + It 'Infers type of common commands with Path parameter' { + $ast = [Language.Parser]::ParseInput("Get-ChildItem -Path:$p/foo.txt", [ref] $tokens, [ref] $errors) + $cmdAst = $ast.EndBlock.Statements[0].Pipelineelements[0] + $res = [AstTypeInference]::InferTypeOf( $cmdAst ) + + $res.Count | Should be 2 + foreach ($r in $res) { + $r.Name -in 'System.IO.FileInfo', 'System.IO.DirectoryInfo' | Should be $true + } + } + + It 'Infers type of common commands with LiteralPath parameter' { + $ast = [Language.Parser]::ParseInput("Get-ChildItem -LiteralPath:$p/foo.txt", [ref] $tokens, [ref] $errors) + $cmdAst = $ast.EndBlock.Statements[0].Pipelineelements[0] + $res = [AstTypeInference]::InferTypeOf( $cmdAst ) + + $res.Count | Should be 2 + foreach ($r in $res) { + $r.Name -in 'System.IO.FileInfo', 'System.IO.DirectoryInfo' | Should be $true + } + } + } + + It 'Infers type of variable $_ in hashtable in command parameter' { + $variableAst = {1..10 | Format-table @{n = 'x'; ex = {$_}}}.ast.Find( {param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst]}, $true) + $res = [AstTypeInference]::InferTypeOf( $variableAst) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + It 'Infers type of variable $_ in hashtable from Array' { + $variableAst = { [int[]]::new(10) | Format-table @{n = 'x'; ex = {$_}}}.ast.Find( {param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst]}, $true) + $res = [AstTypeInference]::InferTypeOf( $variableAst) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + It 'Infers type of variable $_ in hashtable from generic IEnumerable ' { + $variableAst = { [System.Collections.Generic.List[int]]::new() | Format-table @{n = 'x'; ex = {$_}}}.ast.Find( {param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst]}, $true) + $res = [AstTypeInference]::InferTypeOf( $variableAst) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + It 'Infers type of variable $_ command parameter' { + $variableAst = { 1..10 | Group-Object {$_.Length}}.ast.Find( {param($a) $a -is [System.Management.Automation.Language.VariableExpressionAst]}, $true) + $res = [AstTypeInference]::InferTypeOf( $variableAst) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + It 'Infers type of function member' { + $res = [AstTypeInference]::InferTypeOf( { + class X { + [DateTime] GetDate() { return [datetime]::Now } + } + }.Ast.Find( {param($ast) $ast -is [System.Management.Automation.Language.FunctionMemberAst]}, $true)) + + $res.Count | Should be 0 + } + + It 'Infers type of MemberExpression on class property' { + class X { + [DateTime] $Date + } + $x = [X]::new() + $res = [AstTypeInference]::InferTypeOf( { + $x.Date + }.Ast.Find( {param($ast) $ast -is [System.Management.Automation.Language.MemberExpressionAst] -and $ast.Member.Value -eq 'Date'}, $true)) + + $res.Count | Should be 1 + $res.Name | Should Be System.DateTime + } + + It 'Infers type of MemberExpression on class Method' { + class X { + [DateTime] GetDate() { return [DateTime]::Now } + } + $x = [X]::new() + $res = [AstTypeInference]::InferTypeOf( { + $x.GetDate() + }.Ast.Find( {param($ast) $ast -is [System.Management.Automation.Language.MemberExpressionAst] -and $ast.Member.Value -eq 'GetDate'}, $true)) + + $res.Count | Should be 1 + $res.Name | Should Be System.DateTime + } + + + It 'Infers type of note property with safe eval' -Skip { + $res = [AstTypeInference]::InferTypeOf( { + [pscustomobject] @{ + A = '' + }.A + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be 'System.String' + } + + It 'Infers type of invoke operator scriptblock' -Skip { + $res = [AstTypeInference]::InferTypeOf( { + & {1} + }.Ast) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + + + It 'Infers type of script property with safe eval' -Skip { + class Y {} + Update-TypeData -TypeName Y -MemberName SafeEvalScriptProp -MemberType ScriptProperty -Value {1} -Force + $res = [AstTypeInference]::InferTypeOf( { + [Y]::new().SafeEvalScriptProp + }.Ast, [TypeInferenceRuntimePermissions]::AllowSafeEval) + + $res.Count | Should be 1 + $res.Name | Should be System.Int32 + } + + It 'Infers type of base ctor' -Skip { + $res = [AstTypeInference]::InferTypeOf( { + class BaseType { + [string] $s + BaseType([string]$s) {$this.s = $s} + } + class X : BaseType { + X() : base("foo") {} + } + }.Ast.Find( {param($ast) $ast -is [System.Management.Automation.Language.BaseCtorInvokeMemberExpressionAst]}, $true)) + + $res.Count | Should be BaseType + } +} + +Describe "AstTypeInference tests" -Tags CI { + It "Infers type from integer with passed in powershell instance" { + $powerShell = [PowerShell]::Create([RunspaceMode]::CurrentRunspace) + $res = [AstTypeInference]::InferTypeOf( { 1 }.Ast, $powerShell) + $res.Count | Should Be 1 + $res.Name | Should be 'System.Int32' + } + + It "Infers type from integer with passed in powershell instance and typeinferencespermissions" { + $powerShell = [PowerShell]::Create([RunspaceMode]::CurrentRunspace) + $v = 1 + $res = [AstTypeInference]::InferTypeOf( { $v }.Ast, $powerShell, [TypeInferenceRuntimePermissions]::AllowSafeEval) + $res.Name | Should be 'System.Int32' + } + +}