Skip to content

Commit

Permalink
Fix userinfo retriaval with post call
Browse files Browse the repository at this point in the history
  • Loading branch information
bhagyasakalanka committed Aug 14, 2024
1 parent cebcda0 commit 3a72eef
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.oltu.oauth2.common.message.OAuthResponse;
import org.wso2.carbon.identity.oauth.common.OAuth2ErrorCodes;
import org.wso2.carbon.identity.oauth.common.OAuthConstants;
import org.wso2.carbon.identity.oauth.endpoint.OAuthRequestWrapper;
import org.wso2.carbon.identity.oauth.endpoint.user.impl.UserInfoEndpointConfig;
import org.wso2.carbon.identity.oauth.endpoint.user.impl.UserInfoJWTResponse;
import org.wso2.carbon.identity.oauth.user.UserInfoAccessTokenValidator;
Expand All @@ -35,19 +34,14 @@
import org.wso2.carbon.identity.oauth.user.UserInfoResponseBuilder;
import org.wso2.carbon.identity.oauth2.dto.OAuth2TokenValidationResponseDTO;

import java.util.List;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.ResponseBuilder;

Expand All @@ -68,7 +62,6 @@ public class OpenIDConnectUserEndpoint {

@GET
@Path("/")
@Consumes("application/x-www-form-urlencoded")
public Response getUserClaims(@Context HttpServletRequest request) throws OAuthSystemException {

String userInfoResponse;
Expand Down Expand Up @@ -107,12 +100,11 @@ public Response getUserClaims(@Context HttpServletRequest request) throws OAuthS

@POST
@Path("/")
@Consumes("application/x-www-form-urlencoded")
@Produces("application/json")
public Response getUserClaimsPost(@Context HttpServletRequest request, MultivaluedMap<String, String> paramMap)
public Response getUserClaimsPost(@Context HttpServletRequest request)
throws OAuthSystemException {

return getUserClaims(new OAuthRequestWrapper(request, (Map<String, List<String>>) paramMap));
return getUserClaims(request);
}

private ResponseBuilder getResponseBuilderWithCacheControlHeaders() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.commons.lang.StringUtils;
import org.apache.oltu.oauth2.common.error.OAuthError;
import org.wso2.carbon.identity.oauth.endpoint.util.EndpointUtil;
import org.wso2.carbon.identity.oauth.user.UserInfoEndpointException;
import org.wso2.carbon.identity.oauth.user.UserInfoRequestValidator;

Expand All @@ -30,6 +31,7 @@
import java.nio.charset.StandardCharsets;

import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.HttpMethod;
import javax.ws.rs.core.HttpHeaders;

/**
Expand All @@ -40,27 +42,63 @@
public class UserInforRequestDefaultValidator implements UserInfoRequestValidator {

private static final String US_ASCII = "US-ASCII";
private static final String ACCESS_TOKEN_PARAM = "access_token";
private static final String ACCESS_TOKEN_PARAM = "access_token=";
private static final String BEARER = "Bearer";
private static final String CONTENT_TYPE_HEADER_VALUE = "application/x-www-form-urlencoded";
public static final String CHARSET = "charset=";

@Override
public String validateRequest(HttpServletRequest request) throws UserInfoEndpointException {

String authzHeaders = request.getHeader(HttpHeaders.AUTHORIZATION);
String accessToken = request.getParameter(ACCESS_TOKEN_PARAM);
if (StringUtils.isBlank(authzHeaders) && StringUtils.isNotBlank(accessToken)) {
return accessToken;
}
if (authzHeaders == null) {
String contentTypeHeaders = request.getHeader(HttpHeaders.CONTENT_TYPE);
// To validate the Content_Type header.
if (StringUtils.isBlank(contentTypeHeaders)) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST,
"Authorization or Content-Type header is missing");
}

if (StringUtils.isBlank(authzHeaders)) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST, "Bearer token missing");
}
// Restricting passing the access token via request body in GET requests.
if (HttpMethod.GET.equals(request.getMethod())) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST,
"Authorization header is missing");
}
if (contentTypeHeaders.trim().startsWith(CONTENT_TYPE_HEADER_VALUE)) {
String charset = getCharsetFromContentType(contentTypeHeaders);

// Use a default charset if none is provided
Charset encodingCharset;
try {
encodingCharset = charset != null ? Charset.forName(charset) : StandardCharsets.UTF_8;
} catch (IllegalArgumentException e) {
encodingCharset = StandardCharsets.UTF_8;
}
String[] arrAccessToken = new String[2];
String requestBody = EndpointUtil.readRequestBody(request, encodingCharset);
String[] arrAccessTokenNew;
// To check whether the entity-body consist entirely of ASCII [USASCII] characters.
if (!isPureAscii(requestBody)) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST,
"Body contains non ASCII characters");
}
if (requestBody.contains(ACCESS_TOKEN_PARAM)) {
arrAccessToken = requestBody.trim().split(ACCESS_TOKEN_PARAM);
if (arrAccessToken[1].contains("&")) {
arrAccessTokenNew = arrAccessToken[1].split("&", 2);
return arrAccessTokenNew[0];
}
}
return arrAccessToken[1];
} else {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST,
"Content-Type header is wrong");
}
}
String[] authzHeaderInfo = authzHeaders.trim().split(" ");
if (authzHeaderInfo.length < 2 || !BEARER.equals(authzHeaderInfo[0])) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST, "Bearer token missing");
}

return authzHeaderInfo[1];
}

Expand All @@ -76,4 +114,17 @@ public static boolean isPureAscii(String requestBody) {
}
return true;
}

private String getCharsetFromContentType(String contentTypeHeader) {
// Split the Content-Type header value to extract charset
String[] parts = contentTypeHeader.split(";");

for (String part : parts) {
String trimmedPart = part.trim();
if (trimmedPart.toLowerCase().startsWith(CHARSET)) {
return trimmedPart.substring(CHARSET.length()).trim();
}
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.oltu.oauth2.as.request.OAuthAuthzRequest;
import org.apache.oltu.oauth2.as.response.OAuthASResponse;
import org.apache.oltu.oauth2.common.OAuth;
import org.apache.oltu.oauth2.common.error.OAuthError;
import org.apache.oltu.oauth2.common.exception.OAuthProblemException;
import org.apache.oltu.oauth2.common.exception.OAuthSystemException;
import org.owasp.encoder.Encode;
Expand Down Expand Up @@ -87,6 +88,7 @@
import org.wso2.carbon.identity.oauth.endpoint.message.OAuthMessage;
import org.wso2.carbon.identity.oauth.par.core.ParAuthService;
import org.wso2.carbon.identity.oauth.par.exceptions.ParClientException;
import org.wso2.carbon.identity.oauth.user.UserInfoEndpointException;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2Exception;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2ScopeConsentException;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2ScopeException;
Expand Down Expand Up @@ -128,6 +130,7 @@
import java.lang.reflect.InvocationTargetException;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
Expand All @@ -136,6 +139,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -2145,4 +2149,25 @@ public static void preHandleParRequest(HttpServletRequest request, Map<String, S
throw new ParClientException(e.getErrorCode(), e.getMessage());
}
}

/**
* Read request body from Servlet request.
*
* @param request Http servlet request.
* @return Request body.
* @throws UserInfoEndpointException If an error occurred while reading the request body.
*/
public static String readRequestBody(HttpServletRequest request, Charset charset) throws UserInfoEndpointException {

StringBuilder stringBuilder = new StringBuilder();
try (Scanner scanner = new Scanner(request.getInputStream(), charset.name())) {
while (scanner.hasNextLine()) {
stringBuilder.append(scanner.nextLine());
}
} catch (IOException e) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST,
"Unable to read the request body");
}
return stringBuilder.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.wso2.carbon.identity.oauth2.util.OAuth2Util;

import java.lang.reflect.Method;
import java.util.Enumeration;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand Down Expand Up @@ -169,21 +168,7 @@ public void testGetUserClaims(String authResponse, String errorMessage, String e
assertEquals(metadataValue2, "[no-cache]", "Values are not equal");
assertNotNull(response);
assertEquals(response.getEntity().toString(), authResponse, "Response values are not same");

when(httpServletRequest.getParameterNames()).thenReturn(new Enumeration<String>() {
@Override
public boolean hasMoreElements() {

return false;
}

@Override
public String nextElement() {

return null;
}
});
openIDConnectUserEndpoint.getUserClaimsPost(httpServletRequest, paramMap);
openIDConnectUserEndpoint.getUserClaimsPost(httpServletRequest);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@
import org.wso2.carbon.identity.oauth2.OAuth2TokenValidationService;
import org.wso2.carbon.identity.oauth2.dto.OAuth2TokenValidationResponseDTO;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Scanner;
import java.io.InputStream;

import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.HttpHeaders;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
Expand All @@ -54,7 +57,6 @@ public class UserInfoISAccessTokenValidatorTest {
private final String token = "ZWx1c3VhcmlvOnlsYWNsYXZl";
private final String basicAuthHeader = "Bearer " + token;
private static String contentTypeHeaderValue = "application/x-www-form-urlencoded";
private Scanner scanner;

@BeforeClass
public void setup() {
Expand Down Expand Up @@ -144,4 +146,43 @@ private void prepareHttpServletRequest(String authorization, String contentType)
when(httpServletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(authorization);
lenient().when(httpServletRequest.getHeader(HttpHeaders.CONTENT_TYPE)).thenReturn(contentType);
}

@DataProvider
public Object[][] requestBody() {

return new Object[][]{{contentTypeHeaderValue, "", null}, {contentTypeHeaderValue, null, null},
{contentTypeHeaderValue, "access_token=" + token, token},
{contentTypeHeaderValue, "access_token=" + token + "&someOtherParam=value", token},
{contentTypeHeaderValue, "otherParam=value2&access_token=" + token + "&someOtherParam=value", token}};
}

@Test(dataProvider = "requestBody")
public void testValidateTokenWithRequestBodySuccess(String contentType, String requestBody, String expected)
throws Exception {

String token = testValidateTokenWithRequestBody(contentType, requestBody, true);
assertEquals(token, expected, "Expected token did not receive");
}

private String testValidateTokenWithRequestBody(String contentType, String requestBody, boolean mockScanner)
throws Exception {

prepareHttpServletRequest(null, contentType);
if (mockScanner) {
ServletInputStream inputStream = new ServletInputStream() {
private InputStream stream =
new ByteArrayInputStream(requestBody == null ? "".getBytes() : requestBody.getBytes());

@Override
public int read() throws IOException {

return stream.read();
}
};
doReturn(inputStream).when(httpServletRequest).getInputStream();
} else {
when(httpServletRequest.getInputStream()).thenThrow(new IOException());
}
return userInforRequestDefaultValidator.validateRequest(httpServletRequest);
}
}

0 comments on commit 3a72eef

Please sign in to comment.