diff --git a/core/src/main/java/org/apache/cxf/endpoint/Client.java b/core/src/main/java/org/apache/cxf/endpoint/Client.java index 27ec1b34259..5f55b2fd097 100644 --- a/core/src/main/java/org/apache/cxf/endpoint/Client.java +++ b/core/src/main/java/org/apache/cxf/endpoint/Client.java @@ -283,4 +283,21 @@ void invoke(ClientCallback callback, * @return the Bus */ Bus getBus(); + + /** + * if {@link Client#isThreadLocalRequestContext()} it will clean the request context map. This method + * allows to avoid potential leaks programmatically, instead of waiting the garbage collector intervention. + */ + default void clearThreadLocalRequestContexts() { + + } + + /** + * It will clean the response context map. This method allows to avoid potential leaks programmatically, instead + * of waiting the garbage collector intervention. + */ + default void clearThreadLocalResponseContexts() { + + } + } diff --git a/core/src/main/java/org/apache/cxf/endpoint/ClientImpl.java b/core/src/main/java/org/apache/cxf/endpoint/ClientImpl.java index 6beecacd71e..800ed4d07e3 100644 --- a/core/src/main/java/org/apache/cxf/endpoint/ClientImpl.java +++ b/core/src/main/java/org/apache/cxf/endpoint/ClientImpl.java @@ -71,6 +71,10 @@ import org.apache.cxf.transport.MessageObserver; import org.apache.cxf.workqueue.SynchronousExecutor; +import static java.lang.Thread.currentThread; + + + public class ClientImpl extends AbstractBasicInterceptorProvider implements Client, Retryable, MessageObserver { @@ -96,11 +100,11 @@ public class ClientImpl protected Map currentRequestContext = new ConcurrentHashMap(8, 0.75f, 4); protected Thread latestContextThread; - protected Map> requestContext - = Collections.synchronizedMap(new WeakHashMap>()); + protected Map> requestContext = Collections + .synchronizedMap(new WeakHashMap>()); - protected Map> responseContext - = Collections.synchronizedMap(new WeakHashMap>()); + protected Map> responseContext = Collections + .synchronizedMap(new WeakHashMap>()); protected Executor executor; @@ -208,7 +212,7 @@ private EndpointInfo findEndpoint(Service svc, QName port) { } return epfo; } - + for (ServiceInfo svcfo : svc.getServiceInfos()) { for (EndpointInfo e : svcfo.getEndpoints()) { BindingInfo bfo = e.getBinding(); @@ -240,26 +244,26 @@ public Endpoint getEndpoint() { public Map getRequestContext() { if (isThreadLocalRequestContext()) { - final Thread t = Thread.currentThread(); - if (!requestContext.containsKey(t)) { + if (!requestContext.containsKey(getThreadName())) { Map freshRequestContext = new EchoContext(currentRequestContext); - requestContext.put(t, freshRequestContext); + requestContext.put(getThreadName(), freshRequestContext); } - latestContextThread = t; - return requestContext.get(t); + latestContextThread = currentThread(); + return requestContext.get(getThreadName()); } return currentRequestContext; } + public Map getResponseContext() { - if (!responseContext.containsKey(Thread.currentThread())) { - final Thread t = Thread.currentThread(); - responseContext.put(t, new HashMap() { + if (!responseContext.containsKey(currentThread().getName())) { + responseContext.put(getThreadName(), new HashMap() { private static final long serialVersionUID = 1L; + @Override public void clear() { super.clear(); try { - for (Map.Entry> ent : responseContext.entrySet()) { + for (Map.Entry> ent : responseContext.entrySet()) { if (ent.getValue() == this) { responseContext.remove(ent.getKey()); return; @@ -271,7 +275,7 @@ public void clear() { } }); } - return responseContext.get(Thread.currentThread()); + return responseContext.get(getThreadName()); } public boolean isThreadLocalRequestContext() { @@ -343,7 +347,7 @@ public Object[] invoke(BindingOperationInfo oi, return invoke(oi, params, context, exchange); } finally { if (responseContext != null) { - responseContext.put(Thread.currentThread(), resp); + responseContext.put(getThreadName(), resp); } } } @@ -356,7 +360,7 @@ public Object[] invoke(BindingOperationInfo oi, if (context != null) { Map resp = CastUtils.cast((Map)context.get(RESPONSE_CONTEXT)); if (resp != null && responseContext != null) { - responseContext.put(Thread.currentThread(), resp); + responseContext.put(getThreadName(), resp); } } } @@ -514,7 +518,7 @@ public void onMessage(Message message) { // handle the right response List resList = null; Message inMsg = message.getExchange().getInMessage(); - Map ctx = responseContext.get(Thread.currentThread()); + Map ctx = responseContext.get(getThreadName()); resList = CastUtils.cast(inMsg.getContent(List.class)); Object[] result = resList == null ? null : resList.toArray(); callback.handleResponse(ctx, result); @@ -639,7 +643,7 @@ protected Object[] processResult(Message message, resContext.putAll(inMsg); // remove the recursive reference if present resContext.remove(Message.INVOCATION_CONTEXT); - responseContext.put(Thread.currentThread(), resContext); + responseContext.put(getThreadName(), resContext); } resList = CastUtils.cast(inMsg.getContent(List.class)); } @@ -651,9 +655,9 @@ protected Object[] processResult(Message message, throw ex; } - if (resList == null + if (resList == null && oi != null && !oi.getOperationInfo().isOneWay()) { - + BindingOperationInfo boi = oi; if (boi.isUnwrapped()) { boi = boi.getWrappedOperation(); @@ -815,7 +819,7 @@ public void onMessage(Message message) { resCtx = CastUtils.cast((Map) resCtx .get(RESPONSE_CONTEXT)); if (resCtx != null) { - responseContext.put(Thread.currentThread(), resCtx); + responseContext.put(getThreadName(), resCtx); } // remove callback so that it won't be invoked twice callback = message.getExchange().remove(ClientCallback.class); @@ -844,7 +848,7 @@ public void onMessage(Message message) { .get(Message.INVOCATION_CONTEXT)); resCtx = CastUtils.cast((Map)resCtx.get(RESPONSE_CONTEXT)); if (resCtx != null && responseContext != null) { - responseContext.put(Thread.currentThread(), resCtx); + responseContext.put(getThreadName(), resCtx); } try { Object obj[] = processResult(message, message.getExchange(), @@ -1064,14 +1068,14 @@ public EchoContext(Map sharedMap) { public void reload() { super.clear(); - super.putAll(requestContext.get(latestContextThread)); + super.putAll(requestContext.get(latestContextThread.getName())); } - + @Override public void clear() { super.clear(); try { - for (Map.Entry> ent : requestContext.entrySet()) { + for (Map.Entry> ent : requestContext.entrySet()) { if (ent.getValue() == this) { requestContext.remove(ent.getKey()); return; @@ -1083,12 +1087,29 @@ public void clear() { } } - public void setExecutor(Executor executor) { if (!SynchronousExecutor.isA(executor)) { this.executor = executor; } } + /** + * Returns the current thread name. As this value will be used as a key in a {@link WeakHashMap}, + * we should use the String constructor for avoiding scenarios where strong references are kept + * to the string names causing that the map entries can't be garbage collected. + * @return + */ + private String getThreadName() { + return new String(currentThread().getName()); + } + @Override + public void clearThreadLocalRequestContexts() { + requestContext.clear(); + } + + @Override + public void clearThreadLocalResponseContexts() { + responseContext.clear(); + } } diff --git a/core/src/test/java/org/apache/cxf/endpoint/ClientImplTest.java b/core/src/test/java/org/apache/cxf/endpoint/ClientImplTest.java new file mode 100644 index 00000000000..464084be630 --- /dev/null +++ b/core/src/test/java/org/apache/cxf/endpoint/ClientImplTest.java @@ -0,0 +1,94 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.cxf.endpoint; + +import java.util.Map; + +import org.apache.cxf.Bus; + +import org.junit.Before; +import org.junit.Test; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +public class ClientImplTest { + + private Bus bus = mock(Bus.class); + private Endpoint endpoint = mock(Endpoint.class); + + @Before + public void setUp() throws Exception { + expect(bus.getExtension(ClientLifeCycleManager.class)).andReturn(null); + expect(bus.getExtension(ClassLoader.class)).andReturn(null); + replay(bus); + } + + @Test + public void requestContextIsThreadLocal() { + Client client = new ClientImpl(bus, endpoint); + client.setThreadLocalRequestContext(true); + Map requestContext = client.getRequestContext(); + assertSame(requestContext, client.getRequestContext()); + } + + @Test + public void requestContextIsGarbageCollected() { + Client client = new ClientImpl(bus, endpoint); + client.setThreadLocalRequestContext(true); + Map requestContext = client.getRequestContext(); + System.gc(); + assertNotSame(requestContext, client.getRequestContext()); + } + + @Test + public void requestContextCleanOnDemand() { + Client client = new ClientImpl(bus, endpoint); + client.setThreadLocalRequestContext(true); + Map requestContext = client.getRequestContext(); + client.clearThreadLocalRequestContexts(); + assertNotSame(requestContext, client.getRequestContext()); + } + + @Test + public void responseContextIsThreadLocal() { + Client client = new ClientImpl(bus, endpoint); + Map requestContext = client.getResponseContext(); + assertSame(requestContext, client.getResponseContext()); + } + + @Test + public void responseContextIsGarbageCollected() { + Client client = new ClientImpl(bus, endpoint); + Map responseContext = client.getResponseContext(); + System.gc(); + assertNotSame(responseContext, client.getResponseContext()); + } + + @Test + public void responseContextCleanOnDemand() { + Client client = new ClientImpl(bus, endpoint); + Map requestContext = client.getResponseContext(); + client.clearThreadLocalResponseContexts(); + assertNotSame(requestContext, client.getResponseContext()); + } +}