From 7a324f7e8e9c65776fbab894dc6cb1150eb5f3f4 Mon Sep 17 00:00:00 2001 From: Brad Baker Date: Wed, 3 Oct 2018 17:00:08 +1000 Subject: [PATCH] Alternative fix to dead lock challenge (#1255) * Added @Override as part of errorprone code health check * Revert "Added @Override as part of errorprone code health check" This reverts commit 38dfab1 * Brads attempt at https://github.com/graphql-java/graphql-java/pull/1234 * Missed the test --- .../FieldLevelTrackingApproach.java | 40 +++- .../dataloader/DataLoaderHangingTest.groovy | 219 ++++++++++++++++++ 2 files changed, 248 insertions(+), 11 deletions(-) create mode 100644 src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderHangingTest.groovy diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/FieldLevelTrackingApproach.java b/src/main/java/graphql/execution/instrumentation/dataloader/FieldLevelTrackingApproach.java index 12f9829eba..fd7090931f 100644 --- a/src/main/java/graphql/execution/instrumentation/dataloader/FieldLevelTrackingApproach.java +++ b/src/main/java/graphql/execution/instrumentation/dataloader/FieldLevelTrackingApproach.java @@ -94,13 +94,13 @@ public String toString() { '}'; } - public void dispatchIfNotDispatchedBefore(int level, Runnable dispatch) { + public boolean dispatchIfNotDispatchedBefore(int level) { if (dispatchedLevels.contains(level)) { Assert.assertShouldNeverHappen("level " + level + " already dispatched"); - return; + return false; } dispatchedLevels.add(level); - dispatch.run(); + return true; } public void clearAndMarkCurrentLevelAsReady(int level) { @@ -151,17 +151,25 @@ public void onCompleted(ExecutionResult result, Throwable t) { @Override public void onFieldValuesInfo(List fieldValueInfoList) { + boolean dispatchNeeded; synchronized (callStack) { - handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel); + dispatchNeeded = handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel); + } + if (dispatchNeeded) { + dispatch(); } } @Override public void onDeferredField(List field) { + boolean dispatchNeeded; // fake fetch count for this field synchronized (callStack) { callStack.increaseFetchCount(curLevel); - dispatchIfNeeded(callStack, curLevel); + dispatchNeeded = dispatchIfNeeded(callStack, curLevel); + } + if (dispatchNeeded) { + dispatch(); } } }; @@ -170,7 +178,7 @@ public void onDeferredField(List field) { // // thread safety : called with synchronised(callStack) // - private void handleOnFieldValuesInfo(List fieldValueInfoList, CallStack callStack, int curLevel) { + private boolean handleOnFieldValuesInfo(List fieldValueInfoList, CallStack callStack, int curLevel) { callStack.increaseHappenedOnFieldValueCalls(curLevel); int expectedStrategyCalls = 0; for (FieldValueInfo fieldValueInfo : fieldValueInfoList) { @@ -181,7 +189,7 @@ private void handleOnFieldValuesInfo(List fieldValueInfoList, Ca } } callStack.increaseExpectedStrategyCalls(curLevel + 1, expectedStrategyCalls); - dispatchIfNeeded(callStack, curLevel + 1); + return dispatchIfNeeded(callStack, curLevel + 1); } private int getCountForList(FieldValueInfo fieldValueInfo) { @@ -215,8 +223,12 @@ public void onCompleted(ExecutionResult result, Throwable t) { @Override public void onFieldValueInfo(FieldValueInfo fieldValueInfo) { + boolean dispatchNeeded; synchronized (callStack) { - handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), callStack, level); + dispatchNeeded = handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), callStack, level); + } + if (dispatchNeeded) { + dispatch(); } } }; @@ -230,10 +242,15 @@ public InstrumentationContext beginFieldFetch(InstrumentationFieldFetchP @Override public void onDispatched(CompletableFuture result) { + boolean dispatchNeeded; synchronized (callStack) { callStack.increaseFetchCount(level); - dispatchIfNeeded(callStack, level); + dispatchNeeded = dispatchIfNeeded(callStack, level); } + if (dispatchNeeded) { + dispatch(); + } + } @Override @@ -246,10 +263,11 @@ public void onCompleted(Object result, Throwable t) { // // thread safety : called with synchronised(callStack) // - private void dispatchIfNeeded(CallStack callStack, int level) { + private boolean dispatchIfNeeded(CallStack callStack, int level) { if (levelReady(callStack, level)) { - callStack.dispatchIfNotDispatchedBefore(level, this::dispatch); + return callStack.dispatchIfNotDispatchedBefore(level); } + return false; } // diff --git a/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderHangingTest.groovy b/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderHangingTest.groovy new file mode 100644 index 0000000000..1db32d4399 --- /dev/null +++ b/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderHangingTest.groovy @@ -0,0 +1,219 @@ +package graphql.execution.instrumentation.dataloader + +import graphql.ExecutionInput +import graphql.ExecutionResult +import graphql.GraphQL +import graphql.TestUtil +import graphql.execution.Async +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment +import graphql.schema.idl.RuntimeWiring +import org.apache.commons.lang3.concurrent.BasicThreadFactory +import org.dataloader.BatchLoader +import org.dataloader.DataLoader +import org.dataloader.DataLoaderOptions +import org.dataloader.DataLoaderRegistry +import spock.lang.Specification + +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletionStage +import java.util.concurrent.SynchronousQueue +import java.util.concurrent.ThreadFactory +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.TimeUnit + +import static graphql.schema.idl.TypeRuntimeWiring.newTypeWiring + +class DataLoaderHangingTest extends Specification { + + public static final int NUM_OF_REPS = 50 + + def "deadlock attempt"() { + setup: + def sdl = """ + type Album { + id: ID! + title: String! + artist: Artist + songs( + limit: Int, + nextToken: String + ): ModelSongConnection + } + + type Artist { + id: ID! + name: String! + albums( + limit: Int, + nextToken: String + ): ModelAlbumConnection + songs( + limit: Int, + nextToken: String + ): ModelSongConnection + } + + type ModelAlbumConnection { + items: [Album] + nextToken: String + } + + type ModelArtistConnection { + items: [Artist] + nextToken: String + } + + type ModelSongConnection { + items: [Song] + nextToken: String + } + + type Query { + listArtists(limit: Int, nextToken: String): ModelArtistConnection + } + + type Song { + id: ID! + title: String! + artist: Artist + album: Album + } + """ + + ThreadFactory threadFactory = new BasicThreadFactory.Builder() + .namingPattern("resolver-chain-thread-%d").build() + def executor = new ThreadPoolExecutor(15, 15, 0L, + TimeUnit.MILLISECONDS, new SynchronousQueue<>(), threadFactory, + new ThreadPoolExecutor.CallerRunsPolicy()) + + def dataLoaderAlbums = new DataLoader(new BatchLoader>() { + @Override + CompletionStage>> load(List keys) { + return CompletableFuture.supplyAsync({ + def limit = keys.first().getArgument("limit") as Integer + return keys.collect({ k -> + def albums = [] + for (int i = 1; i <= limit; i++) { + albums.add(['id': "artist-$k.source.id-$i", 'title': "album-$i"]) + } + def albumsConnection = ['nextToken': 'album-next', 'items': albums] + return albumsConnection + }) + }, executor) + } + }, DataLoaderOptions.newOptions().setMaxBatchSize(5)) + + def dataLoaderSongs = new DataLoader(new BatchLoader>() { + @Override + CompletionStage>> load(List keys) { + return CompletableFuture.supplyAsync({ + def limit = keys.first().getArgument("limit") as Integer + return keys.collect({ k -> + def songs = [] + for (int i = 1; i <= limit; i++) { + songs.add(['id': "album-$k.source.id-$i", 'title': "song-$i"]) + } + def songsConnection = ['nextToken': 'song-next', 'items': songs] + return songsConnection + }) + }, executor) + } + }, DataLoaderOptions.newOptions().setMaxBatchSize(5)) + + def dataLoaderRegistry = new DataLoaderRegistry() + dataLoaderRegistry.register("artist.albums", dataLoaderAlbums) + dataLoaderRegistry.register("album.songs", dataLoaderSongs) + + + def albumsDf = new MyForwardingDataFetcher(dataLoaderAlbums) + def songsDf = new MyForwardingDataFetcher(dataLoaderSongs) + + def dataFetcherArtists = new DataFetcher() { + @Override + Object get(DataFetchingEnvironment environment) { + def limit = environment.getArgument("limit") as Integer + def artists = [] + for (int i = 1; i <= limit; i++) { + artists.add(['id': "artist-$i", 'name': "artist-$i"]) + } + return ['nextToken': 'artist-next', 'items': artists] + } + } + + def wiring = RuntimeWiring.newRuntimeWiring() + .type(newTypeWiring("Query") + .dataFetcher("listArtists", dataFetcherArtists)) + .type(newTypeWiring("Artist") + .dataFetcher("albums", albumsDf)) + .type(newTypeWiring("Album") + .dataFetcher("songs", songsDf)) + .build() + + def schema = TestUtil.schema(sdl, wiring) + + when: + def graphql = GraphQL.newGraphQL(schema) + .instrumentation(new DataLoaderDispatcherInstrumentation(dataLoaderRegistry)) + .build() + + then: "execution shouldn't hang" + List> futures = [] + for (int i = 0; i < NUM_OF_REPS; i++) { + def result = graphql.executeAsync(ExecutionInput.newExecutionInput() + .query(""" + query getArtistsWithData { + listArtists(limit: 1) { + items { + name + albums(limit: 200) { + items { + title + # Uncommenting the following causes query to timeout + songs(limit: 5) { + nextToken + items { + title + } + } + } + } + } + } + } + """) + .build()) + result.whenComplete({ res, error -> + if (error) { + throw error + } + assert res.errors.empty + }) + // add all futures + futures.add(result) + } + // wait for each future to complete and grab the results + Async.each(futures) + .whenComplete({ results, error -> + if (error) { + throw error + } + results.each { assert it.errors.empty } + }) + .join() + } + + static class MyForwardingDataFetcher implements DataFetcher> { + + private final DataLoader dataLoader + + public MyForwardingDataFetcher(DataLoader dataLoader) { + this.dataLoader = dataLoader + } + + @Override + CompletableFuture get(DataFetchingEnvironment environment) { + return dataLoader.load(environment) + } + } +} \ No newline at end of file