diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/FieldLevelTrackingApproach.java b/src/main/java/graphql/execution/instrumentation/dataloader/FieldLevelTrackingApproach.java index 0d239e3591..fb5b7024fd 100644 --- a/src/main/java/graphql/execution/instrumentation/dataloader/FieldLevelTrackingApproach.java +++ b/src/main/java/graphql/execution/instrumentation/dataloader/FieldLevelTrackingApproach.java @@ -151,9 +151,7 @@ public void onCompleted(ExecutionResult result, Throwable t) { @Override public void onFieldValuesInfo(List fieldValueInfoList) { - synchronized (callStack) { - handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel); - } + handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel); } @Override @@ -171,16 +169,18 @@ public void onDeferredField(List field) { // thread safety : called with synchronised(callStack) // private void handleOnFieldValuesInfo(List fieldValueInfoList, CallStack callStack, int curLevel) { - callStack.increaseHappenedOnFieldValueCalls(curLevel); - int expectedStrategyCalls = 0; - for (FieldValueInfo fieldValueInfo : fieldValueInfoList) { - if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.OBJECT) { - expectedStrategyCalls++; - } else if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.LIST) { - expectedStrategyCalls += getCountForList(fieldValueInfo); + synchronized (callStack) { + callStack.increaseHappenedOnFieldValueCalls(curLevel); + int expectedStrategyCalls = 0; + for (FieldValueInfo fieldValueInfo : fieldValueInfoList) { + if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.OBJECT) { + expectedStrategyCalls++; + } else if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.LIST) { + expectedStrategyCalls += getCountForList(fieldValueInfo); + } } + callStack.increaseExpectedStrategyCalls(curLevel + 1, expectedStrategyCalls); } - callStack.increaseExpectedStrategyCalls(curLevel + 1, expectedStrategyCalls); dispatchIfNeeded(callStack, curLevel + 1); } @@ -215,9 +215,7 @@ public void onCompleted(ExecutionResult result, Throwable t) { @Override public void onFieldValueInfo(FieldValueInfo fieldValueInfo) { - synchronized (callStack) { - handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), callStack, level); - } + handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), callStack, level); } }; } @@ -232,8 +230,8 @@ public InstrumentationContext beginFieldFetch(InstrumentationFieldFetchP public void onDispatched(CompletableFuture result) { synchronized (callStack) { callStack.increaseFetchCount(level); - dispatchIfNeeded(callStack, level); } + dispatchIfNeeded(callStack, level); } @Override @@ -256,15 +254,17 @@ private void dispatchIfNeeded(CallStack callStack, int level) { // thread safety : called with synchronised(callStack) // private boolean levelReady(CallStack callStack, int level) { - if (level == 1) { - // level 1 is special: there is only one strategy call and that's it - return callStack.allFetchesHappened(1); - } - if (levelReady(callStack, level - 1) && callStack.allOnFieldCallsHappened(level - 1) - && callStack.allStrategyCallsHappened(level) && callStack.allFetchesHappened(level)) { - return true; + synchronized (callStack) { + if (level == 1) { + // level 1 is special: there is only one strategy call and that's it + return callStack.allFetchesHappened(1); + } + if (levelReady(callStack, level - 1) && callStack.allOnFieldCallsHappened(level - 1) + && callStack.allStrategyCallsHappened(level) && callStack.allFetchesHappened(level)) { + return true; + } + return false; } - return false; } void dispatch() { diff --git a/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderHangingTest.groovy b/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderHangingTest.groovy new file mode 100644 index 0000000000..be69d6ca1a --- /dev/null +++ b/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderHangingTest.groovy @@ -0,0 +1,218 @@ +package graphql.execution.instrumentation.dataloader + +import graphql.ExecutionInput +import graphql.ExecutionResult +import graphql.GraphQL +import graphql.TestUtil +import graphql.execution.Async +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment +import graphql.schema.idl.RuntimeWiring +import org.apache.commons.lang3.concurrent.BasicThreadFactory +import org.dataloader.BatchLoader +import org.dataloader.DataLoader +import org.dataloader.DataLoaderOptions +import org.dataloader.DataLoaderRegistry +import spock.lang.Specification + +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletionStage +import java.util.concurrent.SynchronousQueue +import java.util.concurrent.ThreadFactory +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.TimeUnit + +import static graphql.schema.idl.TypeRuntimeWiring.newTypeWiring + +class DataLoaderHangingTest extends Specification { + + public static final int NUM_OF_REPS = 50 + + def "deadlock attempt"() { + setup: + def sdl = """ + type Album { + id: ID! + title: String! + artist: Artist + songs( + limit: Int, + nextToken: String + ): ModelSongConnection + } + + type Artist { + id: ID! + name: String! + albums( + limit: Int, + nextToken: String + ): ModelAlbumConnection + songs( + limit: Int, + nextToken: String + ): ModelSongConnection + } + + type ModelAlbumConnection { + items: [Album] + nextToken: String + } + + type ModelArtistConnection { + items: [Artist] + nextToken: String + } + + type ModelSongConnection { + items: [Song] + nextToken: String + } + + type Query { + listArtists(limit: Int, nextToken: String): ModelArtistConnection + } + + type Song { + id: ID! + title: String! + artist: Artist + album: Album + } + """ + + ThreadFactory threadFactory = new BasicThreadFactory.Builder() + .namingPattern("resolver-chain-thread-%d").build() + def executor = new ThreadPoolExecutor(15, 15, 0L, + TimeUnit.MILLISECONDS, new SynchronousQueue<>(), threadFactory, + new ThreadPoolExecutor.CallerRunsPolicy()) + + def dataLoaderAlbums = new DataLoader(new BatchLoader>() { + @Override + CompletionStage>> load(List keys) { + return CompletableFuture.supplyAsync({ + def limit = keys.first().getArgument("limit") as Integer + return keys.collect({ k -> + def albums = [] + for (int i = 1; i <= limit; i++) { + albums.add(['id': "artist-$k.source.id-$i", 'title': "album-$i"]) + } + def albumsConnection = ['nextToken': 'album-next', 'items': albums] + return albumsConnection + }) + }, executor) + } + }, DataLoaderOptions.newOptions().setMaxBatchSize(5)) + + def dataLoaderSongs = new DataLoader(new BatchLoader>() { + @Override + CompletionStage>> load(List keys) { + return CompletableFuture.supplyAsync({ + def limit = keys.first().getArgument("limit") as Integer + return keys.collect({ k -> + def songs = [] + for (int i = 1; i <= limit; i++) { + songs.add(['id': "album-$k.source.id-$i", 'title': "song-$i"]) + } + def songsConnection = ['nextToken': 'song-next', 'items': songs] + return songsConnection + }) + }, executor) + } + }, DataLoaderOptions.newOptions().setMaxBatchSize(5)) + + def dataLoaderRegistry = new DataLoaderRegistry() + dataLoaderRegistry.register("artist.albums", dataLoaderAlbums) + dataLoaderRegistry.register("album.songs", dataLoaderSongs) + + def albumsDf = new MyForwardingDataFetcher(dataLoaderAlbums) + def songsDf = new MyForwardingDataFetcher(dataLoaderSongs) + + def dataFetcherArtists = new DataFetcher() { + @Override + Object get(DataFetchingEnvironment environment) { + def limit = environment.getArgument("limit") as Integer + def artists = [] + for (int i = 1; i <= limit; i++) { + artists.add(['id': "artist-$i", 'name': "artist-$i"]) + } + return ['nextToken': 'artist-next', 'items': artists] + } + } + + def wiring = RuntimeWiring.newRuntimeWiring() + .type(newTypeWiring("Query") + .dataFetcher("listArtists", dataFetcherArtists)) + .type(newTypeWiring("Artist") + .dataFetcher("albums", albumsDf)) + .type(newTypeWiring("Album") + .dataFetcher("songs", songsDf)) + .build() + + def schema = TestUtil.schema(sdl, wiring) + + when: + def graphql = GraphQL.newGraphQL(schema) + .instrumentation(new DataLoaderDispatcherInstrumentation(dataLoaderRegistry)) + .build() + + then: "execution shouldn't hang" + List> futures = [] + for (int i = 0; i < NUM_OF_REPS; i++) { + def result = graphql.executeAsync(ExecutionInput.newExecutionInput() + .query(""" + query getArtistsWithData { + listArtists(limit: 1) { + items { + name + albums(limit: 200) { + items { + title + # Uncommenting the following causes query to timeout + songs(limit: 5) { + nextToken + items { + title + } + } + } + } + } + } + } + """) + .build()) + result.whenComplete({ res, error -> + if (error) { + throw error + } + assert res.errors.empty + }) + // add all futures + futures.add(result) + } + // wait for each future to complete and grab the results + Async.each(futures) + .whenComplete({ results, error -> + if (error) { + throw error + } + results.each { assert it.errors.empty } + }) + .join() + } + + static class MyForwardingDataFetcher implements DataFetcher> { + + private final DataLoader dataLoader + + public MyForwardingDataFetcher(DataLoader dataLoader) { + this.dataLoader = dataLoader + } + + @Override + CompletableFuture get(DataFetchingEnvironment environment) { + return dataLoader.load(environment) + } + } +}