Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix userinfo retriaval with post call #2544

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.oltu.oauth2.common.message.OAuthResponse;
import org.wso2.carbon.identity.oauth.common.OAuth2ErrorCodes;
import org.wso2.carbon.identity.oauth.common.OAuthConstants;
import org.wso2.carbon.identity.oauth.endpoint.OAuthRequestWrapper;
import org.wso2.carbon.identity.oauth.endpoint.user.impl.UserInfoEndpointConfig;
import org.wso2.carbon.identity.oauth.endpoint.user.impl.UserInfoJWTResponse;
import org.wso2.carbon.identity.oauth.user.UserInfoAccessTokenValidator;
Expand All @@ -35,19 +34,14 @@
import org.wso2.carbon.identity.oauth.user.UserInfoResponseBuilder;
import org.wso2.carbon.identity.oauth2.dto.OAuth2TokenValidationResponseDTO;

import java.util.List;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.ResponseBuilder;

Expand All @@ -68,7 +62,6 @@ public class OpenIDConnectUserEndpoint {

@GET
@Path("/")
@Consumes("application/x-www-form-urlencoded")
public Response getUserClaims(@Context HttpServletRequest request) throws OAuthSystemException {

String userInfoResponse;
Expand Down Expand Up @@ -107,12 +100,11 @@ public Response getUserClaims(@Context HttpServletRequest request) throws OAuthS

@POST
@Path("/")
@Consumes("application/x-www-form-urlencoded")
@Produces("application/json")
public Response getUserClaimsPost(@Context HttpServletRequest request, MultivaluedMap<String, String> paramMap)
public Response getUserClaimsPost(@Context HttpServletRequest request)
throws OAuthSystemException {

return getUserClaims(new OAuthRequestWrapper(request, (Map<String, List<String>>) paramMap));
return getUserClaims(request);
}

private ResponseBuilder getResponseBuilderWithCacheControlHeaders() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.commons.lang.StringUtils;
import org.apache.oltu.oauth2.common.error.OAuthError;
import org.wso2.carbon.identity.oauth.endpoint.util.EndpointUtil;
import org.wso2.carbon.identity.oauth.user.UserInfoEndpointException;
import org.wso2.carbon.identity.oauth.user.UserInfoRequestValidator;

Expand All @@ -30,6 +31,7 @@
import java.nio.charset.StandardCharsets;

import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.HttpMethod;
import javax.ws.rs.core.HttpHeaders;

/**
Expand All @@ -40,27 +42,63 @@
public class UserInforRequestDefaultValidator implements UserInfoRequestValidator {

private static final String US_ASCII = "US-ASCII";
private static final String ACCESS_TOKEN_PARAM = "access_token";
private static final String ACCESS_TOKEN_PARAM = "access_token=";
private static final String BEARER = "Bearer";
private static final String CONTENT_TYPE_HEADER_VALUE = "application/x-www-form-urlencoded";
public static final String CHARSET = "charset=";

@Override
public String validateRequest(HttpServletRequest request) throws UserInfoEndpointException {

String authzHeaders = request.getHeader(HttpHeaders.AUTHORIZATION);
String accessToken = request.getParameter(ACCESS_TOKEN_PARAM);
if (StringUtils.isBlank(authzHeaders) && StringUtils.isNotBlank(accessToken)) {
return accessToken;
}
if (authzHeaders == null) {
String contentTypeHeaders = request.getHeader(HttpHeaders.CONTENT_TYPE);
// To validate the Content_Type header.
if (StringUtils.isBlank(contentTypeHeaders)) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST,
"Authorization or Content-Type header is missing");
}

if (StringUtils.isBlank(authzHeaders)) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST, "Bearer token missing");
}
// Restricting passing the access token via request body in GET requests.
if (HttpMethod.GET.equals(request.getMethod())) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST,
"Authorization header is missing");
}
if (contentTypeHeaders.trim().startsWith(CONTENT_TYPE_HEADER_VALUE)) {
String charset = getCharsetFromContentType(contentTypeHeaders);

// Use a default charset if none is provided
Charset encodingCharset;
try {
encodingCharset = charset != null ? Charset.forName(charset) : StandardCharsets.UTF_8;
} catch (IllegalArgumentException e) {
encodingCharset = StandardCharsets.UTF_8;
}
String[] arrAccessToken = new String[2];
String requestBody = EndpointUtil.readRequestBody(request, encodingCharset);
String[] arrAccessTokenNew;
// To check whether the entity-body consist entirely of ASCII [USASCII] characters.
if (!isPureAscii(requestBody)) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST,
"Body contains non ASCII characters");
}
if (requestBody.contains(ACCESS_TOKEN_PARAM)) {
arrAccessToken = requestBody.trim().split(ACCESS_TOKEN_PARAM);
if (arrAccessToken[1].contains("&")) {
arrAccessTokenNew = arrAccessToken[1].split("&", 2);
return arrAccessTokenNew[0];
}
}
return arrAccessToken[1];
} else {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST,
"Content-Type header is wrong");
}
}
String[] authzHeaderInfo = authzHeaders.trim().split(" ");
if (authzHeaderInfo.length < 2 || !BEARER.equals(authzHeaderInfo[0])) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST, "Bearer token missing");
}

return authzHeaderInfo[1];
}

Expand All @@ -76,4 +114,17 @@ public static boolean isPureAscii(String requestBody) {
}
return true;
}

private String getCharsetFromContentType(String contentTypeHeader) {
// Split the Content-Type header value to extract charset
String[] parts = contentTypeHeader.split(";");

for (String part : parts) {
String trimmedPart = part.trim();
if (trimmedPart.toLowerCase().startsWith(CHARSET)) {
return trimmedPart.substring(CHARSET.length()).trim();
}
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.oltu.oauth2.as.request.OAuthAuthzRequest;
import org.apache.oltu.oauth2.as.response.OAuthASResponse;
import org.apache.oltu.oauth2.common.OAuth;
import org.apache.oltu.oauth2.common.error.OAuthError;
import org.apache.oltu.oauth2.common.exception.OAuthProblemException;
import org.apache.oltu.oauth2.common.exception.OAuthSystemException;
import org.owasp.encoder.Encode;
Expand Down Expand Up @@ -87,6 +88,7 @@
import org.wso2.carbon.identity.oauth.endpoint.message.OAuthMessage;
import org.wso2.carbon.identity.oauth.par.core.ParAuthService;
import org.wso2.carbon.identity.oauth.par.exceptions.ParClientException;
import org.wso2.carbon.identity.oauth.user.UserInfoEndpointException;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2Exception;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2ScopeConsentException;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2ScopeException;
Expand Down Expand Up @@ -128,6 +130,7 @@
import java.lang.reflect.InvocationTargetException;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
Expand All @@ -136,6 +139,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -2145,4 +2149,25 @@ public static void preHandleParRequest(HttpServletRequest request, Map<String, S
throw new ParClientException(e.getErrorCode(), e.getMessage());
}
}

/**
* Read request body from Servlet request.
*
* @param request Http servlet request.
* @return Request body.
* @throws UserInfoEndpointException If an error occurred while reading the request body.
*/
public static String readRequestBody(HttpServletRequest request, Charset charset) throws UserInfoEndpointException {

StringBuilder stringBuilder = new StringBuilder();
try (Scanner scanner = new Scanner(request.getInputStream(), charset.name())) {
while (scanner.hasNextLine()) {
stringBuilder.append(scanner.nextLine());
}
} catch (IOException e) {
throw new UserInfoEndpointException(OAuthError.ResourceResponse.INVALID_REQUEST,
"Unable to read the request body");
}
return stringBuilder.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.wso2.carbon.identity.oauth2.util.OAuth2Util;

import java.lang.reflect.Method;
import java.util.Enumeration;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand Down Expand Up @@ -169,21 +168,7 @@ public void testGetUserClaims(String authResponse, String errorMessage, String e
assertEquals(metadataValue2, "[no-cache]", "Values are not equal");
assertNotNull(response);
assertEquals(response.getEntity().toString(), authResponse, "Response values are not same");

when(httpServletRequest.getParameterNames()).thenReturn(new Enumeration<String>() {
@Override
public boolean hasMoreElements() {

return false;
}

@Override
public String nextElement() {

return null;
}
});
openIDConnectUserEndpoint.getUserClaimsPost(httpServletRequest, paramMap);
openIDConnectUserEndpoint.getUserClaimsPost(httpServletRequest);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@
import org.wso2.carbon.identity.oauth2.OAuth2TokenValidationService;
import org.wso2.carbon.identity.oauth2.dto.OAuth2TokenValidationResponseDTO;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Scanner;
import java.io.InputStream;

import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.HttpHeaders;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
Expand All @@ -54,7 +57,6 @@ public class UserInfoISAccessTokenValidatorTest {
private final String token = "ZWx1c3VhcmlvOnlsYWNsYXZl";
private final String basicAuthHeader = "Bearer " + token;
private static String contentTypeHeaderValue = "application/x-www-form-urlencoded";
private Scanner scanner;

@BeforeClass
public void setup() {
Expand Down Expand Up @@ -144,4 +146,43 @@ private void prepareHttpServletRequest(String authorization, String contentType)
when(httpServletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(authorization);
lenient().when(httpServletRequest.getHeader(HttpHeaders.CONTENT_TYPE)).thenReturn(contentType);
}

@DataProvider
public Object[][] requestBody() {

return new Object[][]{{contentTypeHeaderValue, "", null}, {contentTypeHeaderValue, null, null},
{contentTypeHeaderValue, "access_token=" + token, token},
{contentTypeHeaderValue, "access_token=" + token + "&someOtherParam=value", token},
{contentTypeHeaderValue, "otherParam=value2&access_token=" + token + "&someOtherParam=value", token}};
}

@Test(dataProvider = "requestBody")
public void testValidateTokenWithRequestBodySuccess(String contentType, String requestBody, String expected)
throws Exception {

String token = testValidateTokenWithRequestBody(contentType, requestBody, true);
assertEquals(token, expected, "Expected token did not receive");
}

private String testValidateTokenWithRequestBody(String contentType, String requestBody, boolean mockScanner)
throws Exception {

prepareHttpServletRequest(null, contentType);
if (mockScanner) {
ServletInputStream inputStream = new ServletInputStream() {
private InputStream stream =
new ByteArrayInputStream(requestBody == null ? "".getBytes() : requestBody.getBytes());

@Override
public int read() throws IOException {

return stream.read();
}
};
doReturn(inputStream).when(httpServletRequest).getInputStream();
} else {
when(httpServletRequest.getInputStream()).thenThrow(new IOException());
}
return userInforRequestDefaultValidator.validateRequest(httpServletRequest);
}
}
Loading