Skip to content

Commit

Permalink
Merge pull request #166 from aka4rKO/3.0.0-tokenfilter
Browse files Browse the repository at this point in the history
[OB3] Provide config for making transport cert optional for the token endpoint
  • Loading branch information
Ashi1993 authored Oct 30, 2024
2 parents 1b0c3a3 + b65c5a9 commit 378d608
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ priority=1
name = "TokenFilter"
class = "com.wso2.openbanking.accelerator.identity.token.TokenFilter"

[tomcat.filter.init_params]
isTransportCertificateMandatory = true

[[tomcat.filter_mapping]]
name = "TokenFilter"
url_pattern = "/token"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,9 @@ url_pattern = "/oauth2/authorize*"
name = "TokenFilter"
class = "com.wso2.openbanking.accelerator.identity.token.TokenFilter"

[tomcat.filter.init_params]
isTransportCertificateMandatory = true

[[tomcat.filter_mapping]]
name = "TokenFilter"
url_pattern = "/token"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ priority=1
name = "TokenFilter"
class = "com.wso2.openbanking.accelerator.identity.token.TokenFilter"

[tomcat.filter.init_params]
isTransportCertificateMandatory = true

[[tomcat.filter_mapping]]
name = "TokenFilter"
url_pattern = "/token"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class TokenFilter implements Filter {
private static DefaultTokenFilter defaultTokenFilter;
private String clientId = null;
private static List<OBIdentityFilterValidator> validators = new ArrayList<>();
private boolean isTransportCertMandatory;

private static final String BASIC_AUTH_ERROR_MSG = "Unable to find client id in the request. " +
"Invalid Authorization header found.";
Expand All @@ -73,6 +74,15 @@ public class TokenFilter implements Filter {
public void init(FilterConfig filterConfig) {

ServletContext context = filterConfig.getServletContext();

String isTransportCertMandatoryConf = filterConfig.getInitParameter("isTransportCertificateMandatory");
if (isTransportCertMandatoryConf == null) {
// By default, mandating the transport certificate
isTransportCertMandatory = true;
} else {
isTransportCertMandatory = Boolean.parseBoolean(isTransportCertMandatoryConf);
}

context.log("TokenFilter initialized");
}

Expand Down Expand Up @@ -130,6 +140,11 @@ private ServletRequest appendTransportHeader(ServletRequest request, ServletResp
ServletException, IOException, CertificateEncodingException {

if (request instanceof HttpServletRequest) {

if (!isTransportCertMandatory) {
return request;
}

Object certAttribute = request.getAttribute(IdentityCommonConstants.JAVAX_SERVLET_REQUEST_CERTIFICATE);
String x509Certificate = ((HttpServletRequest) request).getHeader(IdentityCommonUtil.getMTLSAuthHeader());
if (new IdentityCommonHelper().isTransportCertAsHeaderEnabled() && x509Certificate != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.testng.annotations.Test;

import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
Expand Down Expand Up @@ -106,7 +107,8 @@ public void nonRegulatoryAppWithAuthorizationHeaderTest() throws Exception {
}

@Test(description = "Test the certificate in context/header is mandated")
public void noCertificateTest() throws IOException, OpenBankingException, ServletException {
public void noCertificateTest() throws IOException, OpenBankingException, ServletException, NoSuchFieldException,
IllegalAccessException {

Map<String, Object> configMap = new HashMap<>();
PowerMockito.mockStatic(IdentityCommonUtil.class);
Expand All @@ -120,6 +122,11 @@ public void noCertificateTest() throws IOException, OpenBankingException, Servle
PowerMockito.when(IdentityCommonUtil.getRegulatoryFromSPMetaData("test")).thenReturn(true);
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader())
.thenReturn(IdentityCommonConstants.CERTIFICATE_HEADER);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
Map<String, String> responseMap = TestUtil.getResponse(response.getOutputStream());
assertEquals(response.getStatus(), HttpStatus.SC_BAD_REQUEST);
Expand All @@ -129,7 +136,8 @@ public void noCertificateTest() throws IOException, OpenBankingException, Servle
}

@Test(description = "Test the certificate in attribute is present if config is disabled")
public void certificateIsNotPresentInAttributeTest() throws IOException, OpenBankingException, ServletException {
public void certificateIsNotPresentInAttributeTest() throws IOException, OpenBankingException, ServletException,
NoSuchFieldException, IllegalAccessException {

Map<String, Object> configMap = new HashMap<>();
PowerMockito.mockStatic(IdentityCommonUtil.class);
Expand All @@ -144,6 +152,11 @@ public void certificateIsNotPresentInAttributeTest() throws IOException, OpenBan
Mockito.doReturn(new DefaultTokenFilter()).when(filter).getDefaultTokenFilter();
PowerMockito.when(IdentityCommonUtil.getRegulatoryFromSPMetaData("test")).thenReturn(true);
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader()).thenReturn(TestConstants.CERTIFICATE_HEADER);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
Map<String, String> responseMap = TestUtil.getResponse(response.getOutputStream());
assertEquals(response.getStatus(), HttpStatus.SC_BAD_REQUEST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* Test constants.
*/
public class TestConstants {
public static final String IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME = "isTransportCertMandatory";
public static final String TARGET_STREAM = "targetStream";
public static final String CERTIFICATE_HEADER = "x-wso2-mutual-auth-cert";
public static final String EXPIRED_CERTIFICATE_CONTENT = "-----BEGIN CERTIFICATE-----" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.testng.annotations.Test;

import java.io.IOException;
import java.lang.reflect.Field;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -123,12 +124,17 @@ public void certificateAttributeValidation() throws Exception {
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader()).thenReturn(TestConstants.CERTIFICATE_HEADER);
PowerMockito.when(IdentityCommonUtil.getCertificateFromAttribute(cert)).thenReturn(cert);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
assertEquals(response.getStatus(), HttpServletResponse.SC_OK);
}

@Test(description = "Test whether the certificate header is present")
public void noCertificateHeaderValidation() throws IOException, OpenBankingException, ServletException {
public void noCertificateHeaderValidation() throws IOException, OpenBankingException, ServletException,
NoSuchFieldException, IllegalAccessException {
Map<String, Object> configMap = new HashMap<>();
PowerMockito.mockStatic(IdentityCommonUtil.class);
configMap.put(IdentityCommonConstants.ENABLE_TRANSPORT_CERT_AS_HEADER, true);
Expand All @@ -148,6 +154,10 @@ public void noCertificateHeaderValidation() throws IOException, OpenBankingExcep
PowerMockito.when(IdentityCommonUtil.getRegulatoryFromSPMetaData("test")).thenReturn(true);
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader()).thenReturn(TestConstants.CERTIFICATE_HEADER);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
Map<String, String> responseMap = TestUtil.getResponse(response.getOutputStream());
assertEquals(response.getStatus(), HttpStatus.SC_BAD_REQUEST);
Expand All @@ -159,7 +169,8 @@ public void noCertificateHeaderValidation() throws IOException, OpenBankingExcep


@Test(description = "Test the certificate in attribute is passed as a header")
public void certificateIsPresentInAttributeTest() throws IOException, OpenBankingException, ServletException {
public void certificateIsPresentInAttributeTest() throws IOException, OpenBankingException, ServletException,
NoSuchFieldException, IllegalAccessException {
MTLSEnforcementValidator mtlsEnforcementValidator = Mockito.spy(MTLSEnforcementValidator.class);
PowerMockito.mockStatic(IdentityCommonUtil.class);

Expand All @@ -183,12 +194,17 @@ public void certificateIsPresentInAttributeTest() throws IOException, OpenBankin
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader()).thenReturn(TestConstants.CERTIFICATE_HEADER);
PowerMockito.when(IdentityCommonUtil.getCertificateFromAttribute(cert)).thenReturn(cert);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
assertEquals(response.getStatus(), HttpServletResponse.SC_OK);
}

@Test(description = "Test whether the certificate attribute is valid")
public void invalidCertificateHeaderValidation() throws IOException, OpenBankingException, ServletException {
public void invalidCertificateHeaderValidation() throws IOException, OpenBankingException, ServletException,
NoSuchFieldException, IllegalAccessException {
Map<String, Object> configMap = new HashMap<>();
PowerMockito.mockStatic(IdentityCommonUtil.class);
configMap.put(IdentityCommonConstants.ENABLE_TRANSPORT_CERT_AS_HEADER, true);
Expand All @@ -209,6 +225,10 @@ public void invalidCertificateHeaderValidation() throws IOException, OpenBanking
PowerMockito.when(IdentityCommonUtil.getRegulatoryFromSPMetaData("test")).thenReturn(true);
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader()).thenReturn(TestConstants.CERTIFICATE_HEADER);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
Map<String, String> responseMap = TestUtil.getResponse(response.getOutputStream());
assertEquals(response.getStatus(), HttpStatus.SC_BAD_REQUEST);
Expand Down

0 comments on commit 378d608

Please sign in to comment.