diff --git a/src/main/java/graphql/analysis/NodeVisitorWithTypeTracking.java b/src/main/java/graphql/analysis/NodeVisitorWithTypeTracking.java index d74e5cf63e..a817963145 100644 --- a/src/main/java/graphql/analysis/NodeVisitorWithTypeTracking.java +++ b/src/main/java/graphql/analysis/NodeVisitorWithTypeTracking.java @@ -89,10 +89,13 @@ public TraversalControl visitFragmentDefinition(FragmentDefinition node, Travers return TraversalControl.ABORT; } + QueryVisitorFragmentDefinitionEnvironment fragmentEnvironment = new QueryVisitorFragmentDefinitionEnvironmentImpl(node, context); + if (context.getVar(NodeTraverser.LeaveOrEnter.class) == LEAVE) { + postOrderCallback.visitFragmentDefinition(fragmentEnvironment); return TraversalControl.CONTINUE; } - + preOrderCallback.visitFragmentDefinition(fragmentEnvironment); QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class); GraphQLCompositeType typeCondition = (GraphQLCompositeType) schema.getType(node.getTypeCondition().getName()); diff --git a/src/main/java/graphql/analysis/QueryVisitor.java b/src/main/java/graphql/analysis/QueryVisitor.java index 9a39344b54..be95bac4f1 100644 --- a/src/main/java/graphql/analysis/QueryVisitor.java +++ b/src/main/java/graphql/analysis/QueryVisitor.java @@ -16,4 +16,8 @@ public interface QueryVisitor { void visitFragmentSpread(QueryVisitorFragmentSpreadEnvironment queryVisitorFragmentSpreadEnvironment); + default void visitFragmentDefinition(QueryVisitorFragmentDefinitionEnvironment queryVisitorFragmentDefinitionEnvironment) { + + } + } diff --git a/src/main/java/graphql/analysis/QueryVisitorFragmentDefinitionEnvironment.java b/src/main/java/graphql/analysis/QueryVisitorFragmentDefinitionEnvironment.java new file mode 100644 index 0000000000..2d2b51e092 --- /dev/null +++ b/src/main/java/graphql/analysis/QueryVisitorFragmentDefinitionEnvironment.java @@ -0,0 +1,13 @@ +package graphql.analysis; + +import graphql.PublicApi; +import graphql.language.FragmentDefinition; +import graphql.language.Node; +import graphql.util.TraverserContext; + +@PublicApi +public interface QueryVisitorFragmentDefinitionEnvironment { + FragmentDefinition getFragmentDefinition(); + + TraverserContext getTraverserContext(); +} diff --git a/src/main/java/graphql/analysis/QueryVisitorFragmentDefinitionEnvironmentImpl.java b/src/main/java/graphql/analysis/QueryVisitorFragmentDefinitionEnvironmentImpl.java new file mode 100644 index 0000000000..da1349f7b9 --- /dev/null +++ b/src/main/java/graphql/analysis/QueryVisitorFragmentDefinitionEnvironmentImpl.java @@ -0,0 +1,56 @@ +package graphql.analysis; + +import graphql.Internal; +import graphql.language.FragmentDefinition; +import graphql.language.Node; +import graphql.util.TraverserContext; + +import java.util.Objects; + +@Internal +public class QueryVisitorFragmentDefinitionEnvironmentImpl implements QueryVisitorFragmentDefinitionEnvironment { + + private final FragmentDefinition fragmentDefinition; + private final TraverserContext traverserContext; + + + public QueryVisitorFragmentDefinitionEnvironmentImpl(FragmentDefinition fragmentDefinition, TraverserContext traverserContext) { + this.fragmentDefinition = fragmentDefinition; + this.traverserContext = traverserContext; + } + + @Override + public FragmentDefinition getFragmentDefinition() { + return fragmentDefinition; + } + + @Override + public TraverserContext getTraverserContext() { + return traverserContext; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + QueryVisitorFragmentDefinitionEnvironmentImpl that = (QueryVisitorFragmentDefinitionEnvironmentImpl) o; + return Objects.equals(fragmentDefinition, that.fragmentDefinition); + } + + @Override + public int hashCode() { + return Objects.hash(fragmentDefinition); + } + + @Override + public String toString() { + return "QueryVisitorFragmentDefinitionEnvironmentImpl{" + + "fragmentDefinition=" + fragmentDefinition + + '}'; + } +} + diff --git a/src/test/groovy/graphql/analysis/QueryTransformerTest.groovy b/src/test/groovy/graphql/analysis/QueryTransformerTest.groovy index 99afbffb41..1382c50966 100644 --- a/src/test/groovy/graphql/analysis/QueryTransformerTest.groovy +++ b/src/test/groovy/graphql/analysis/QueryTransformerTest.groovy @@ -5,6 +5,7 @@ import graphql.language.Document import graphql.language.Field import graphql.language.NodeUtil import graphql.language.SelectionSet +import graphql.language.TypeName import graphql.parser.Parser import graphql.schema.GraphQLSchema import spock.lang.Specification @@ -207,7 +208,7 @@ class QueryTransformerTest extends Specification { 0 * _ } - def "named fragment is traversed if it is a root and can be transformed"() { + def "fragment definition is traversed if it is a root and can be transformed"() { def query = TestUtil.parseQuery(''' { root { @@ -241,6 +242,15 @@ class QueryTransformerTest extends Specification { }) } } + + @Override + void visitFragmentDefinition(QueryVisitorFragmentDefinitionEnvironment env) { + def changed = env.fragmentDefinition.transform({ builder -> + builder.typeCondition(TypeName.newTypeName("newTypeName").build()) + .name("newFragName") + }) + changeNode(env.traverserContext, changed) + } } @@ -248,6 +258,6 @@ class QueryTransformerTest extends Specification { def newFragment = queryTransformer.transform(visitor) then: printAstCompact(newFragment) == - "fragment frag on Root {fooA {midA {newChild1 newChild2}}}" + "fragment newFragName on newTypeName {fooA {midA {newChild1 newChild2}}}" } } diff --git a/src/test/groovy/graphql/analysis/QueryTraversalTest.groovy b/src/test/groovy/graphql/analysis/QueryTraversalTest.groovy index e3aa19e183..608d115ec8 100644 --- a/src/test/groovy/graphql/analysis/QueryTraversalTest.groovy +++ b/src/test/groovy/graphql/analysis/QueryTraversalTest.groovy @@ -7,6 +7,7 @@ import graphql.language.FragmentDefinition import graphql.language.FragmentSpread import graphql.language.InlineFragment import graphql.language.NodeTraverser +import graphql.language.NodeUtil import graphql.parser.Parser import graphql.schema.GraphQLNonNull import graphql.schema.GraphQLObjectType @@ -302,6 +303,54 @@ class QueryTraversalTest extends Specification { } + + def "test preOrder and postOrder order for fragment definitions"() { + given: + def schema = TestUtil.schema(""" + type Query{ + foo: Foo + bar: String + } + type Foo { + subFoo: String + } + """) + def visitor = Mock(QueryVisitor) + def query = createQuery(""" + { + ...F1 + } + + fragment F1 on Query { + foo { + subFoo + } + } + """) + + def fragments = NodeUtil.getFragmentsByName(query) + + QueryTraversal queryTraversal = QueryTraversal.newQueryTraversal() + .schema(schema) + .root(fragments["F1"]) + .rootParentType(schema.getQueryType()) + .fragmentsByName(fragments) + .variables([:]) + .build() + + when: + queryTraversal.visitPreOrder(visitor) + + then: + 1 * visitor.visitFragmentDefinition({ QueryVisitorFragmentDefinitionEnvironment env -> env.fragmentDefinition == fragments["F1"] }) + + when: + queryTraversal.visitPostOrder(visitor) + + then: + 1 * visitor.visitFragmentDefinition({ QueryVisitorFragmentDefinitionEnvironment env -> env.fragmentDefinition == fragments["F1"] }) + } + def "works for mutations()"() { given: def schema = TestUtil.schema("""