Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code reformatting issues in AI APIs #13011

Merged
merged 1 commit into from
Feb 28, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ private void processInboundRequest(MessageContext messageContext, LLMProviderCon
return;
}

Map<String, Object> roundRobinConfigs = null;
Map<String, Object> roundRobinConfigs;
if (messageContext.getProperty(APIConstants.AIAPIConstants.ROUND_ROBIN_CONFIGS) != null) {
roundRobinConfigs =
(Map<String, Object>) messageContext.getProperty(APIConstants.AIAPIConstants.ROUND_ROBIN_CONFIGS);
Expand All @@ -181,7 +181,7 @@ private void processInboundRequest(MessageContext messageContext, LLMProviderCon
}

if (failoverConfigMap != null && !failoverConfigMap.isEmpty()) {
prepareForFailover(messageContext, providerConfiguration, failoverConfigMap);
initFailover(messageContext, providerConfiguration, failoverConfigMap);
}

}
Expand All @@ -196,9 +196,9 @@ private void processInboundRequest(MessageContext messageContext, LLMProviderCon
* @throws XMLStreamException If an error occurs while processing the XML message.
* @throws IOException If an I/O error occurs during payload handling.
*/
private void prepareForFailover(MessageContext messageContext,
LLMProviderConfiguration providerConfiguration,
Map<String, FailoverPolicyConfigDTO> failoverConfigMap)
private void initFailover(MessageContext messageContext,
LLMProviderConfiguration providerConfiguration,
Map<String, FailoverPolicyConfigDTO> failoverConfigMap)
throws XMLStreamException, IOException, APIManagementException {

org.apache.axis2.context.MessageContext axis2Ctx =
Expand All @@ -211,7 +211,7 @@ private void prepareForFailover(MessageContext messageContext,
if (requestModel == null || failoverConfig == null) {
return;
}
applyFailoverConfig(messageContext, failoverConfig, providerConfiguration);
applyFailoverConfigs(messageContext, failoverConfig, providerConfiguration);
}

/**
Expand Down Expand Up @@ -240,8 +240,8 @@ private String extractRequestModel(LLMProviderMetadata requestModelMetadata,
* @throws IOException If request modification fails.
* @throws APIManagementException If an API management error occurs.
*/
private void applyFailoverConfig(MessageContext messageContext, FailoverPolicyConfigDTO policyConfig,
LLMProviderConfiguration providerConfiguration) throws IOException,
private void applyFailoverConfigs(MessageContext messageContext, FailoverPolicyConfigDTO policyConfig,
LLMProviderConfiguration providerConfiguration) throws IOException,
APIManagementException {

FailoverPolicyDeploymentConfigDTO targetConfig = GatewayUtils.getTargetConfig(messageContext, policyConfig);
Expand All @@ -268,7 +268,7 @@ private void applyFailoverConfig(MessageContext messageContext, FailoverPolicyCo
modifyRequestPayload(failoverEndpoint.getModel(), providerConfiguration, messageContext);
updateTargetEndpoint(messageContext, 1, failoverEndpoint);
}
preserverFailoverPropertiesInMsgContext(messageContext, policyConfig, targetModelEndpoint, failoverEndpoints);
preserveFailoverPropertiesInMsgCtx(messageContext, policyConfig, targetModelEndpoint, failoverEndpoints);
}

/**
Expand All @@ -280,40 +280,40 @@ private void applyFailoverConfig(MessageContext messageContext, FailoverPolicyCo
* @param failoverEndpoints The list of failover endpoints.
* @throws APIManagementException If an API management error occurs.
*/
private void preserverFailoverPropertiesInMsgContext(MessageContext messageContext,
FailoverPolicyConfigDTO policyConfig,
ModelEndpointDTO targetModelEndpoint,
List<ModelEndpointDTO> failoverEndpoints)
private void preserveFailoverPropertiesInMsgCtx(MessageContext messageContext,
FailoverPolicyConfigDTO policyConfig,
ModelEndpointDTO targetModelEndpoint,
List<ModelEndpointDTO> failoverEndpoints)
throws APIManagementException {

Map<String, Object> failoverConfigs = new HashMap<>();
Map<String, Object> failoverConfigurations = new HashMap<>();

failoverConfigs.put(APIConstants.AIAPIConstants.FAILOVER_TARGET_MODEL_ENDPOINT,
failoverConfigurations.put(APIConstants.AIAPIConstants.FAILOVER_TARGET_MODEL_ENDPOINT,
targetModelEndpoint);
failoverConfigs.put(APIConstants.AIAPIConstants.FAILOVER_ENDPOINTS,
failoverConfigurations.put(APIConstants.AIAPIConstants.FAILOVER_ENDPOINTS,
failoverEndpoints);
failoverConfigs.put(APIConstants.AIAPIConstants.SUSPEND_DURATION,
failoverConfigurations.put(APIConstants.AIAPIConstants.SUSPEND_DURATION,
policyConfig.getSuspendDuration() * APIConstants.AIAPIConstants.MILLISECONDS_IN_SECOND);
if (policyConfig.getRequestTimeout() != null) {
messageContext.setProperty(APIConstants.AIAPIConstants.REQUEST_TIMEOUT,
policyConfig.getRequestTimeout());
} else {
messageContext.setProperty(APIConstants.AIAPIConstants.REQUEST_TIMEOUT,
APIUtil.getDefaultRequestTimeoutForFailoverConfigurations());
}
org.apache.axis2.context.MessageContext axis2Ctx =

long requestTimeout = (policyConfig.getRequestTimeout() != null)
? policyConfig.getRequestTimeout()
: APIUtil.getDefaultRequestTimeoutForFailoverConfigurations();
messageContext.setProperty(APIConstants.AIAPIConstants.REQUEST_TIMEOUT,
requestTimeout);

org.apache.axis2.context.MessageContext axis2MessageContext =
((Axis2MessageContext) messageContext).getAxis2MessageContext();

failoverConfigs.put(APIConstants.AIAPIConstants.REQUEST_PAYLOAD,
JsonUtil.jsonPayloadToString(axis2Ctx));
failoverConfigs.put(APIConstants.AIAPIConstants.REQUEST_HEADERS,
axis2Ctx.getProperty(org.apache.axis2.context.MessageContext.TRANSPORT_HEADERS));
failoverConfigs.put(APIConstants.AIAPIConstants.REQUEST_HTTP_METHOD,
axis2Ctx.getProperty(PassThroughConstants.HTTP_METHOD));
failoverConfigs.put(APIConstants.AIAPIConstants.REQUEST_REST_URL_POSTFIX,
axis2Ctx.getProperty(NhttpConstants.REST_URL_POSTFIX));
failoverConfigurations.put(APIConstants.AIAPIConstants.REQUEST_PAYLOAD,
JsonUtil.jsonPayloadToString(axis2MessageContext));
failoverConfigurations.put(APIConstants.AIAPIConstants.REQUEST_HEADERS,
axis2MessageContext.getProperty(org.apache.axis2.context.MessageContext.TRANSPORT_HEADERS));
failoverConfigurations.put(APIConstants.AIAPIConstants.REQUEST_HTTP_METHOD,
axis2MessageContext.getProperty(PassThroughConstants.HTTP_METHOD));
failoverConfigurations.put(APIConstants.AIAPIConstants.REQUEST_REST_URL_POSTFIX,
axis2MessageContext.getProperty(NhttpConstants.REST_URL_POSTFIX));

messageContext.setProperty(APIConstants.AIAPIConstants.FAILOVER_CONFIGS, failoverConfigs);
messageContext.setProperty(APIConstants.AIAPIConstants.FAILOVER_CONFIGS, failoverConfigurations);
}

/**
Expand All @@ -336,20 +336,18 @@ private void handleLoadBalancing(
return;
}

if (APIConstants.AIAPIConstants.INPUT_SOURCE_PAYLOAD.equalsIgnoreCase(targetModelMetadata.getInputSource())) {
ModelEndpointDTO targetModelEndpoint =
(ModelEndpointDTO) roundRobinConfigs.get(APIConstants.AIAPIConstants.TARGET_MODEL_ENDPOINT);

if (APIConstants.AIAPIConstants.INPUT_SOURCE_PAYLOAD.equalsIgnoreCase(targetModelMetadata.getInputSource())) {
org.apache.axis2.context.MessageContext axis2Ctx =
((Axis2MessageContext) messageContext).getAxis2MessageContext();
RelayUtils.buildMessage(axis2Ctx);
ModelEndpointDTO targetModel =
(ModelEndpointDTO) roundRobinConfigs.get(APIConstants.AIAPIConstants.TARGET_MODEL_ENDPOINT);
modifyRequestPayload(targetModel.getModel(), targetModelMetadata, axis2Ctx);
modifyRequestPayload(targetModelEndpoint.getModel(), targetModelMetadata, axis2Ctx);
} else {
log.debug("Unsupported input source for attribute: " + targetModelMetadata.getAttributeName());
}

ModelEndpointDTO targetModelEndpoint =
(ModelEndpointDTO) roundRobinConfigs.get(APIConstants.AIAPIConstants.TARGET_MODEL_ENDPOINT);
messageContext.setProperty(APIConstants.AIAPIConstants.TARGET_ENDPOINT, targetModelEndpoint.getEndpointId());
}

Expand Down Expand Up @@ -485,10 +483,7 @@ private void processOutboundResponse(MessageContext messageContext,
(int) ((Axis2MessageContext) messageContext).getAxis2MessageContext()
.getProperty(APIMgtGatewayConstants.HTTP_SC);

String targetEndpoint = (String) messageContext.getProperty(APIConstants.AIAPIConstants.TARGET_ENDPOINT);

if (handleSuccessfulResponse(messageContext, statusCode, providerConfigs, targetEndpoint,
roundRobinConfigs, failoverConfigs)) {
if (handleSuccessfulResponse(messageContext, statusCode, providerConfigs, roundRobinConfigs, failoverConfigs)) {
return;
}

Expand Down Expand Up @@ -519,13 +514,12 @@ private void processOutboundResponse(MessageContext messageContext,
* @param messageContext The message context containing the request and response data.
* @param statusCode The HTTP status code of the response.
* @param providerConfiguration The LLM provider configuration used for fetching token metadata.
* @param targetEndpoint The target endpoint for the current request.
* @param roundRobinConfigs The configuration for round robin load balancing.
* @param failoverConfigs The configuration for failover handling.
* @return True if the response is successful and further processing is done, false otherwise.
*/
private boolean handleSuccessfulResponse(MessageContext messageContext, int statusCode,
LLMProviderConfiguration providerConfiguration, String targetEndpoint,
LLMProviderConfiguration providerConfiguration,
Map<String, Object> roundRobinConfigs,
Map<String, Object> failoverConfigs) {

Expand All @@ -547,7 +541,8 @@ private boolean handleSuccessfulResponse(MessageContext messageContext, int stat
Long suspendDuration = (Long) roundRobinConfigs
.get(APIConstants.AIAPIConstants.SUSPEND_DURATION);

suspendTargetEndpoint(messageContext, targetEndpoint, targetModelEndpoint.getModel(),
suspendTargetEndpoint(messageContext, targetModelEndpoint.getEndpointId(),
targetModelEndpoint.getModel(),
suspendDuration);
} else if (failoverConfigs != null) {
int currentEndpointIndex = getCurrentFailoverIndex(messageContext);
Expand Down Expand Up @@ -702,7 +697,9 @@ private void updateJsonPayloadWithRequestPayload(MessageContext messageContext,
* @throws IOException If an I/O error occurs during payload modification.
*/
private void modifyRequestPayload(String failoverModel,
LLMProviderConfiguration providerConfiguration, MessageContext messageContext) throws IOException {
LLMProviderConfiguration providerConfiguration,
MessageContext messageContext)
throws IOException {

LLMProviderMetadata targetModelMetadata = getTargetModelMetadata(providerConfiguration);
if (targetModelMetadata == null) {
Expand Down
Loading