Skip to content

Commit

Permalink
Merge pull request #154 from graphql-java/reactive-streams-common-pub…
Browse files Browse the repository at this point in the history
…lisher-impl

Making the Subscribers use a common base class
  • Loading branch information
bbakerman authored May 22, 2024
2 parents 8295396 + 6d3c4eb commit 3fddb8b
Showing 1 changed file with 104 additions and 87 deletions.
191 changes: 104 additions & 87 deletions src/main/java/org/dataloader/DataLoaderHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,13 @@ CompletableFuture<V> load(K key, Object loadContext) {
}
}

@SuppressWarnings("unchecked")
Object getCacheKey(K key) {
return loaderOptions.cacheKeyFunction().isPresent() ?
loaderOptions.cacheKeyFunction().get().getKey(key) : key;
}

@SuppressWarnings("unchecked")
Object getCacheKeyWithContext(K key, Object context) {
return loaderOptions.cacheKeyFunction().isPresent() ?
loaderOptions.cacheKeyFunction().get().getKeyWithContext(key, context) : key;
Expand Down Expand Up @@ -511,6 +513,7 @@ private CompletableFuture<List<V>> invokeBatchPublisher(List<K> keys, List<Objec

BatchLoaderScheduler batchLoaderScheduler = loaderOptions.getBatchLoaderScheduler();
if (batchLoadFunction instanceof BatchPublisherWithContext) {
//noinspection unchecked
BatchPublisherWithContext<K, V> loadFunction = (BatchPublisherWithContext<K, V>) batchLoadFunction;
if (batchLoaderScheduler != null) {
BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber, environment);
Expand All @@ -519,6 +522,7 @@ private CompletableFuture<List<V>> invokeBatchPublisher(List<K> keys, List<Objec
loadFunction.load(keys, subscriber, environment);
}
} else {
//noinspection unchecked
BatchPublisher<K, V> loadFunction = (BatchPublisher<K, V>) batchLoadFunction;
if (batchLoaderScheduler != null) {
BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber);
Expand All @@ -536,6 +540,7 @@ private CompletableFuture<List<V>> invokeMappedBatchPublisher(List<K> keys, List

BatchLoaderScheduler batchLoaderScheduler = loaderOptions.getBatchLoaderScheduler();
if (batchLoadFunction instanceof MappedBatchPublisherWithContext) {
//noinspection unchecked
MappedBatchPublisherWithContext<K, V> loadFunction = (MappedBatchPublisherWithContext<K, V>) batchLoadFunction;
if (batchLoaderScheduler != null) {
BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber, environment);
Expand All @@ -544,6 +549,7 @@ private CompletableFuture<List<V>> invokeMappedBatchPublisher(List<K> keys, List
loadFunction.load(keys, subscriber, environment);
}
} else {
//noinspection unchecked
MappedBatchPublisher<K, V> loadFunction = (MappedBatchPublisher<K, V>) batchLoadFunction;
if (batchLoaderScheduler != null) {
BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber);
Expand Down Expand Up @@ -618,24 +624,23 @@ private static <T> DispatchResult<T> emptyDispatchResult() {
return (DispatchResult<T>) EMPTY_DISPATCH_RESULT;
}

private class DataLoaderSubscriber implements Subscriber<V> {
private abstract class DataLoaderSubscriberBase<T> implements Subscriber<T> {

private final CompletableFuture<List<V>> valuesFuture;
private final List<K> keys;
private final List<Object> callContexts;
private final List<CompletableFuture<V>> queuedFutures;
final CompletableFuture<List<V>> valuesFuture;
final List<K> keys;
final List<Object> callContexts;
final List<CompletableFuture<V>> queuedFutures;

private final List<K> clearCacheKeys = new ArrayList<>();
private final List<V> completedValues = new ArrayList<>();
private int idx = 0;
private boolean onErrorCalled = false;
private boolean onCompleteCalled = false;
List<K> clearCacheKeys = new ArrayList<>();
List<V> completedValues = new ArrayList<>();
boolean onErrorCalled = false;
boolean onCompleteCalled = false;

private DataLoaderSubscriber(
CompletableFuture<List<V>> valuesFuture,
List<K> keys,
List<Object> callContexts,
List<CompletableFuture<V>> queuedFutures
DataLoaderSubscriberBase(
CompletableFuture<List<V>> valuesFuture,
List<K> keys,
List<Object> callContexts,
List<CompletableFuture<V>> queuedFutures
) {
this.valuesFuture = valuesFuture;
this.keys = keys;
Expand All @@ -648,55 +653,97 @@ public void onSubscribe(Subscription subscription) {
subscription.request(keys.size());
}

// onNext may be called by multiple threads - for the time being, we pass 'synchronized' to guarantee
// correctness (at the cost of speed).
@Override
public synchronized void onNext(V value) {
public void onNext(T v) {
assertState(!onErrorCalled, () -> "onError has already been called; onNext may not be invoked.");
assertState(!onCompleteCalled, () -> "onComplete has already been called; onNext may not be invoked.");
}

K key = keys.get(idx);
Object callContext = callContexts.get(idx);
CompletableFuture<V> future = queuedFutures.get(idx);
@Override
public void onComplete() {
assertState(!onErrorCalled, () -> "onError has already been called; onComplete may not be invoked.");
onCompleteCalled = true;
}

@Override
public void onError(Throwable throwable) {
assertState(!onCompleteCalled, () -> "onComplete has already been called; onError may not be invoked.");
onErrorCalled = true;

stats.incrementBatchLoadExceptionCount(new IncrementBatchLoadExceptionCountStatisticsContext<>(keys, callContexts));
}

/*
* A value has arrived - how do we complete the future that's associated with it in a common way
*/
void onNextValue(K key, V value, Object callContext, List<CompletableFuture<V>> futures) {
if (value instanceof Try) {
// we allow the batch loader to return a Try so we can better represent a computation
// that might have worked or not.
//noinspection unchecked
Try<V> tryValue = (Try<V>) value;
if (tryValue.isSuccess()) {
future.complete(tryValue.get());
futures.forEach(f -> f.complete(tryValue.get()));
} else {
stats.incrementLoadErrorCount(new IncrementLoadErrorCountStatisticsContext<>(key, callContext));
future.completeExceptionally(tryValue.getThrowable());
clearCacheKeys.add(keys.get(idx));
futures.forEach(f -> f.completeExceptionally(tryValue.getThrowable()));
clearCacheKeys.add(key);
}
} else {
future.complete(value);
futures.forEach(f -> f.complete(value));
}
}

Throwable unwrapThrowable(Throwable ex) {
if (ex instanceof CompletionException) {
ex = ex.getCause();
}
return ex;
}
}

private class DataLoaderSubscriber extends DataLoaderSubscriberBase<V> {

private int idx = 0;

private DataLoaderSubscriber(
CompletableFuture<List<V>> valuesFuture,
List<K> keys,
List<Object> callContexts,
List<CompletableFuture<V>> queuedFutures
) {
super(valuesFuture, keys, callContexts, queuedFutures);
}

// onNext may be called by multiple threads - for the time being, we pass 'synchronized' to guarantee
// correctness (at the cost of speed).
@Override
public synchronized void onNext(V value) {
super.onNext(value);

K key = keys.get(idx);
Object callContext = callContexts.get(idx);
CompletableFuture<V> future = queuedFutures.get(idx);
onNextValue(key, value, callContext, List.of(future));

completedValues.add(value);
idx++;
}

@Override
public void onComplete() {
assertState(!onErrorCalled, () -> "onError has already been called; onComplete may not be invoked.");
onCompleteCalled = true;

@Override
public synchronized void onComplete() {
super.onComplete();
assertResultSize(keys, completedValues);

possiblyClearCacheEntriesOnExceptions(clearCacheKeys);
valuesFuture.complete(completedValues);
}

@Override
public void onError(Throwable ex) {
assertState(!onCompleteCalled, () -> "onComplete has already been called; onError may not be invoked.");
onErrorCalled = true;

stats.incrementBatchLoadExceptionCount(new IncrementBatchLoadExceptionCountStatisticsContext<>(keys, callContexts));
if (ex instanceof CompletionException) {
ex = ex.getCause();
}
public synchronized void onError(Throwable ex) {
super.onError(ex);
ex = unwrapThrowable(ex);
// Set the remaining keys to the exception.
for (int i = idx; i < queuedFutures.size(); i++) {
K key = keys.get(i);
Expand All @@ -705,33 +752,25 @@ public void onError(Throwable ex) {
// clear any cached view of this key because they all failed
dataLoader.clear(key);
}
valuesFuture.completeExceptionally(ex);
}

}

private class DataLoaderMapEntrySubscriber implements Subscriber<Map.Entry<K, V>> {
private final CompletableFuture<List<V>> valuesFuture;
private final List<K> keys;
private final List<Object> callContexts;
private final List<CompletableFuture<V>> queuedFutures;
private class DataLoaderMapEntrySubscriber extends DataLoaderSubscriberBase<Map.Entry<K, V>> {

private final Map<K, Object> callContextByKey;
private final Map<K, List<CompletableFuture<V>>> queuedFuturesByKey;

private final List<K> clearCacheKeys = new ArrayList<>();
private final Map<K, V> completedValuesByKey = new HashMap<>();
private boolean onErrorCalled = false;
private boolean onCompleteCalled = false;


private DataLoaderMapEntrySubscriber(
CompletableFuture<List<V>> valuesFuture,
List<K> keys,
List<Object> callContexts,
List<CompletableFuture<V>> queuedFutures
CompletableFuture<List<V>> valuesFuture,
List<K> keys,
List<Object> callContexts,
List<CompletableFuture<V>> queuedFutures
) {
this.valuesFuture = valuesFuture;
this.keys = keys;
this.callContexts = callContexts;
this.queuedFutures = queuedFutures;

super(valuesFuture, keys, callContexts, queuedFutures);
this.callContextByKey = new HashMap<>();
this.queuedFuturesByKey = new HashMap<>();
for (int idx = 0; idx < queuedFutures.size(); idx++) {
Expand All @@ -743,42 +782,24 @@ private DataLoaderMapEntrySubscriber(
}
}

@Override
public void onSubscribe(Subscription subscription) {
subscription.request(keys.size());
}

@Override
public void onNext(Map.Entry<K, V> entry) {
assertState(!onErrorCalled, () -> "onError has already been called; onNext may not be invoked.");
assertState(!onCompleteCalled, () -> "onComplete has already been called; onNext may not be invoked.");
public synchronized void onNext(Map.Entry<K, V> entry) {
super.onNext(entry);
K key = entry.getKey();
V value = entry.getValue();

Object callContext = callContextByKey.get(key);
List<CompletableFuture<V>> futures = queuedFuturesByKey.get(key);
if (value instanceof Try) {
// we allow the batch loader to return a Try so we can better represent a computation
// that might have worked or not.
Try<V> tryValue = (Try<V>) value;
if (tryValue.isSuccess()) {
futures.forEach(f -> f.complete(tryValue.get()));
} else {
stats.incrementLoadErrorCount(new IncrementLoadErrorCountStatisticsContext<>(key, callContext));
futures.forEach(f -> f.completeExceptionally(tryValue.getThrowable()));
clearCacheKeys.add(key);
}
} else {
futures.forEach(f -> f.complete(value));
}

onNextValue(key, value, callContext, futures);

completedValuesByKey.put(key, value);
}

@Override
public void onComplete() {
assertState(!onErrorCalled, () -> "onError has already been called; onComplete may not be invoked.");
onCompleteCalled = true;
public synchronized void onComplete() {
super.onComplete();

possiblyClearCacheEntriesOnExceptions(clearCacheKeys);
List<V> values = new ArrayList<>(keys.size());
Expand All @@ -790,14 +811,9 @@ public void onComplete() {
}

@Override
public void onError(Throwable ex) {
assertState(!onCompleteCalled, () -> "onComplete has already been called; onError may not be invoked.");
onErrorCalled = true;

stats.incrementBatchLoadExceptionCount(new IncrementBatchLoadExceptionCountStatisticsContext<>(keys, callContexts));
if (ex instanceof CompletionException) {
ex = ex.getCause();
}
public synchronized void onError(Throwable ex) {
super.onError(ex);
ex = unwrapThrowable(ex);
// Complete the futures for the remaining keys with the exception.
for (int idx = 0; idx < queuedFutures.size(); idx++) {
K key = keys.get(idx);
Expand All @@ -810,6 +826,7 @@ public void onError(Throwable ex) {
dataLoader.clear(key);
}
}
valuesFuture.completeExceptionally(ex);
}
}
}

0 comments on commit 3fddb8b

Please sign in to comment.