Skip to content

Commit

Permalink
Merge pull request #13007 from PasanT9/post-beta-v2
Browse files Browse the repository at this point in the history
Refactor AI API policy Gateway implementation
  • Loading branch information
PasanT9 authored Feb 26, 2025
2 parents a7c71c6 + 01fde10 commit ff9cb84
Show file tree
Hide file tree
Showing 14 changed files with 678 additions and 237 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ public class APIConstants {
public static final String ENDPOINT_SECURITY_PRODUCTION = "production";
public static final String ENDPOINT_SECURITY_SANDBOX = "sandbox";
public static final String ENDPOINT_CONFIG_SESSION_TIMEOUT = "sessionTimeOut";
public static final String OM_ELEMENT_NAME = "name";

public static class AIAPIConstants {
public static final int MILLISECONDS_IN_SECOND = 1000;
Expand Down Expand Up @@ -101,12 +100,13 @@ public static class AIAPIConstants {
public static final String TRAFFIC_FLOW_DIRECTION_IN = "IN";
public static final String TRAFFIC_FLOW_DIRECTION_OUT = "OUT";
public static final String API_LLM_ENDPOINT = "_API_LLMEndpoint_";
public static final String TARGET_MODEL = "TARGET_MODEL";
public static final String ROUND_ROBIN_CONFIGS = "ROUND_ROBIN_CONFIGS";
public static final String FAILOVER_CONFIGS = "FAILOVER_CONFIGS";
public static final String TARGET_MODEL_ENDPOINT = "TARGET_MODEL_ENDPOINT";
public static final String TARGET_ENDPOINT = "TARGET_ENDPOINT";
public static final String FAILOVER_TARGET_MODEL = "FAILOVER_TARGET_MODEL";
public static final String FAILOVER_TARGET_ENDPOINT = "FAILOVER_TARGET_ENDPOINT";
public static final String FAILOVER_TARGET_MODEL_ENDPOINT = "FAILOVER_TARGET_MODEL_ENDPOINT";
public static final String FAILOVER_CONFIG_MAP = "FAILOVER_CONFIG_MAP";
public static final String SUSPEND_DURATION = "SUSPEND_DURATION";
public static final String ENDPOINT_TIMEOUT = "ENDPOINT_TIMEOUT";
public static final String FAILOVER_ENDPOINTS = "FAILOVER_ENDPOINTS";
public static final String REJECT_ENDPOINT = "REJECT";
public static final String DEFAULT_ENDPOINT = "DEFAULT";
Expand All @@ -115,10 +115,10 @@ public static class AIAPIConstants {
public static final String REQUEST_HEADERS = "REQUEST_HEADERS";
public static final String REQUEST_HTTP_METHOD = "REQUEST_HTTP_METHOD";
public static final String REQUEST_REST_URL_POSTFIX = "REQUEST_REST_URL_POSTFIX";
public static final String REQUEST_MODEL = "REQUEST_MODEL";
public static final String CURRENT_ENDPOINT_INDEX = "CURRENT_ENDPOINT_INDEX";
public static final String DEFAULT_PRODUCTION_ENDPOINT_NAME = "DEFAULT PRODUCTION ENDPOINT";
public static final String DEFAULT_SANDBOX_ENDPOINT_NAME = "DEFAULT SANDBOX ENDPOINT";
public static final String ENDPOINT_SEQUENCE = "_EndpointsSeq";
public static final String REQUEST_TIMEOUT = "REQUEST_TIMEOUT";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
public class FailoverPolicyConfigDTO {

private static final long DEFAULT_REQUEST_TIMEOUT = 30000L;
private static final long DEFAULT_SUSPEND_DURATION = 300000L;
private static final long DEFAULT_SUSPEND_DURATION = 0L;
private FailoverPolicyDeploymentConfigDTO production;
private FailoverPolicyDeploymentConfigDTO sandbox;
private Long requestTimeout = DEFAULT_REQUEST_TIMEOUT;
Expand Down Expand Up @@ -119,7 +119,7 @@ public void setSuspendDuration(Long suspendDuration) {
*/
private void validateDuration(Long duration, String fieldName) {

if (duration != null && duration <= 0) {
if (duration != null && duration < 0) {
throw new IllegalArgumentException(fieldName + " must be positive");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
*/
public class RBPolicyConfigDTO {

private static final long DEFAULT_SUSPEND_DURATION = 0L;
private List<ModelEndpointDTO> production;
private List<ModelEndpointDTO> sandbox;
private Long suspendDuration;
private Long suspendDuration = DEFAULT_SUSPEND_DURATION;

/**
* Gets the production model endpoints.
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
import org.wso2.carbon.apimgt.api.gateway.ModelEndpointDTO;
import org.wso2.carbon.apimgt.gateway.internal.DataHolder;
import org.wso2.carbon.apimgt.gateway.utils.GatewayUtils;
import org.wso2.carbon.apimgt.impl.APIConstants;
import org.wso2.carbon.apimgt.api.APIConstants.AIAPIConstants;

import java.util.List;
import java.util.HashMap;
import java.util.Map;

/**
* Mediator responsible for handling AI API failover policies. This mediator processes failover configurations,
Expand Down Expand Up @@ -91,31 +91,20 @@ public boolean mediate(MessageContext messageContext) {
return false;
}

String apiKeyType = (String) messageContext.getProperty(APIConstants.API_KEY_TYPE);

FailoverPolicyDeploymentConfigDTO targetConfig =
APIConstants.API_KEY_TYPE_PRODUCTION.equals(apiKeyType)
? policyConfig.getProduction()
: policyConfig.getSandbox();

if ((targetConfig == null || targetConfig.getFallbackModelEndpoints() == null
|| targetConfig.getFallbackModelEndpoints().isEmpty())) {
log.debug("Failover policy is not set for " + apiKeyType + ", bypassing mediation.");
FailoverPolicyDeploymentConfigDTO targetConfig = GatewayUtils.getTargetConfig(messageContext, policyConfig);
if (targetConfig == null) {
return true;
}

List<ModelEndpointDTO> activeEndpoints =
GatewayUtils.filterActiveEndpoints(targetConfig.getFallbackModelEndpoints(), messageContext);

ModelEndpointDTO targetEndpointModel = targetConfig.getTargetModelEndpoint();
Map<String, FailoverPolicyConfigDTO> failoverConfigMap =
(Map<String, FailoverPolicyConfigDTO>) messageContext.getProperty(AIAPIConstants.FAILOVER_CONFIG_MAP);
if (failoverConfigMap == null) {
failoverConfigMap = new HashMap<>();
}

messageContext.setProperty(AIAPIConstants.FAILOVER_TARGET_ENDPOINT, targetEndpointModel.getEndpointId());
messageContext.setProperty(AIAPIConstants.FAILOVER_TARGET_MODEL, targetEndpointModel.getModel());
messageContext.setProperty(AIAPIConstants.FAILOVER_ENDPOINTS, activeEndpoints);
messageContext.setProperty(AIAPIConstants.SUSPEND_DURATION,
policyConfig.getSuspendDuration() * AIAPIConstants.MILLISECONDS_IN_SECOND);
messageContext.setProperty(AIAPIConstants.ENDPOINT_TIMEOUT,
policyConfig.getRequestTimeout() * AIAPIConstants.MILLISECONDS_IN_SECOND);
ModelEndpointDTO targetModelEndpoint = targetConfig.getTargetModelEndpoint();
failoverConfigMap.put(targetModelEndpoint.getModel(), policyConfig);
messageContext.setProperty(AIAPIConstants.FAILOVER_CONFIG_MAP, failoverConfigMap);

return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@
import org.apache.synapse.MessageContext;
import org.apache.synapse.core.SynapseEnvironment;
import org.apache.synapse.mediators.AbstractMediator;
import org.wso2.carbon.apimgt.api.APIConstants;
import org.wso2.carbon.apimgt.api.APIConstants.AIAPIConstants;
import org.wso2.carbon.apimgt.api.APIManagementException;
import org.wso2.carbon.apimgt.api.gateway.ModelEndpointDTO;
import org.wso2.carbon.apimgt.api.gateway.RBPolicyConfigDTO;
import org.wso2.carbon.apimgt.gateway.internal.DataHolder;
import org.wso2.carbon.apimgt.gateway.utils.GatewayUtils;
import org.wso2.carbon.apimgt.impl.APIConstants;
import org.wso2.carbon.apimgt.impl.utils.APIUtil;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

/**
Expand Down Expand Up @@ -74,34 +79,34 @@ public boolean mediate(MessageContext messageContext) {
endpoints = new Gson().fromJson(roundRobinConfigs, RBPolicyConfigDTO.class);
} catch (JsonSyntaxException e) {
log.error("Failed to parse weighted round robin configuration", e);
messageContext.setProperty(APIConstants.AIAPIConstants.TARGET_ENDPOINT,
APIConstants.AIAPIConstants.REJECT_ENDPOINT);
return false;
}

String apiKeyType = (String) messageContext.getProperty(org.wso2.carbon.apimgt.impl.APIConstants.API_KEY_TYPE);
String apiKeyType = (String) messageContext.getProperty(APIConstants.API_KEY_TYPE);

List<ModelEndpointDTO> selectedEndpoints = org.wso2.carbon.apimgt.impl.APIConstants.API_KEY_TYPE_PRODUCTION
List<ModelEndpointDTO> selectedEndpoints = APIConstants.API_KEY_TYPE_PRODUCTION
.equals(apiKeyType)
? endpoints.getProduction()
: endpoints.getSandbox();

if (selectedEndpoints == null || selectedEndpoints.isEmpty()) {
log.debug("RoundRobin policy is not set for " + apiKeyType + ", bypassing mediation.");
if (log.isDebugEnabled()) {
log.debug("RoundRobin policy is not set for " + apiKeyType + ", bypassing mediation.");
}
return true;
}

List<ModelEndpointDTO> activeEndpoints = GatewayUtils.filterActiveEndpoints(selectedEndpoints, messageContext);

if (!activeEndpoints.isEmpty()) {
ModelEndpointDTO nextEndpoint = getRoundRobinEndpoint(activeEndpoints);
messageContext.setProperty(APIConstants.AIAPIConstants.TARGET_ENDPOINT, nextEndpoint.getEndpointId());
messageContext.setProperty(APIConstants.AIAPIConstants.TARGET_MODEL, nextEndpoint.getModel());
messageContext.setProperty(APIConstants.AIAPIConstants.SUSPEND_DURATION,
endpoints.getSuspendDuration() * APIConstants.AIAPIConstants.MILLISECONDS_IN_SECOND);
Map<String, Object> roundRobinConfigs = new HashMap<>();
roundRobinConfigs.put(AIAPIConstants.TARGET_MODEL_ENDPOINT, nextEndpoint);
roundRobinConfigs.put(AIAPIConstants.SUSPEND_DURATION,
endpoints.getSuspendDuration() * AIAPIConstants.MILLISECONDS_IN_SECOND);
messageContext.setProperty(AIAPIConstants.ROUND_ROBIN_CONFIGS, roundRobinConfigs);
} else {
messageContext.setProperty(APIConstants.AIAPIConstants.TARGET_ENDPOINT,
APIConstants.AIAPIConstants.REJECT_ENDPOINT);
messageContext.setProperty(AIAPIConstants.TARGET_ENDPOINT, AIAPIConstants.REJECT_ENDPOINT);
}
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@
import org.apache.synapse.MessageContext;
import org.apache.synapse.core.SynapseEnvironment;
import org.apache.synapse.mediators.AbstractMediator;
import org.wso2.carbon.apimgt.api.APIConstants;
import org.wso2.carbon.apimgt.api.APIConstants.AIAPIConstants;
import org.wso2.carbon.apimgt.api.APIManagementException;
import org.wso2.carbon.apimgt.api.gateway.ModelEndpointDTO;
import org.wso2.carbon.apimgt.api.gateway.RBPolicyConfigDTO;
import org.wso2.carbon.apimgt.gateway.internal.DataHolder;
import org.wso2.carbon.apimgt.gateway.utils.GatewayUtils;
import org.wso2.carbon.apimgt.impl.APIConstants;
import org.wso2.carbon.apimgt.impl.utils.APIUtil;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
Expand Down Expand Up @@ -74,34 +79,34 @@ public boolean mediate(MessageContext messageContext) {
endpoints = new Gson().fromJson(weightedRoundRobinConfigs, RBPolicyConfigDTO.class);
} catch (JsonSyntaxException e) {
log.error("Failed to parse weighted round robin configuration", e);
messageContext.setProperty(APIConstants.AIAPIConstants.TARGET_ENDPOINT,
APIConstants.AIAPIConstants.REJECT_ENDPOINT);
return false;
}

String apiKeyType = (String) messageContext.getProperty(org.wso2.carbon.apimgt.impl.APIConstants.API_KEY_TYPE);
String apiKeyType = (String) messageContext.getProperty(APIConstants.API_KEY_TYPE);

List<ModelEndpointDTO> selectedEndpoints = org.wso2.carbon.apimgt.impl.APIConstants.API_KEY_TYPE_PRODUCTION
List<ModelEndpointDTO> selectedEndpoints = APIConstants.API_KEY_TYPE_PRODUCTION
.equals(apiKeyType)
? endpoints.getProduction()
: endpoints.getSandbox();

if (selectedEndpoints == null || selectedEndpoints.isEmpty()) {
log.debug("RoundRobin policy is not set for " + apiKeyType + ", bypassing mediation.");
if (log.isDebugEnabled()) {
log.debug("RoundRobin policy is not set for " + apiKeyType + ", bypassing mediation.");
}
return true;
}

List<ModelEndpointDTO> activeEndpoints = GatewayUtils.filterActiveEndpoints(selectedEndpoints, messageContext);

if (activeEndpoints != null && !activeEndpoints.isEmpty()) {
ModelEndpointDTO nextEndpoint = getWeightedRandomEndpoint(activeEndpoints);
messageContext.setProperty(APIConstants.AIAPIConstants.TARGET_ENDPOINT, nextEndpoint.getEndpointId());
messageContext.setProperty(APIConstants.AIAPIConstants.TARGET_MODEL, nextEndpoint.getModel());
messageContext.setProperty(APIConstants.AIAPIConstants.SUSPEND_DURATION,
endpoints.getSuspendDuration() * APIConstants.AIAPIConstants.MILLISECONDS_IN_SECOND);
Map<String, Object> roundRobinConfigs = new HashMap<>();
roundRobinConfigs.put(AIAPIConstants.TARGET_MODEL_ENDPOINT, nextEndpoint);
roundRobinConfigs.put(AIAPIConstants.SUSPEND_DURATION,
endpoints.getSuspendDuration() * AIAPIConstants.MILLISECONDS_IN_SECOND);
messageContext.setProperty(AIAPIConstants.ROUND_ROBIN_CONFIGS, roundRobinConfigs);
} else {
messageContext.setProperty(APIConstants.AIAPIConstants.TARGET_ENDPOINT,
APIConstants.AIAPIConstants.REJECT_ENDPOINT);
messageContext.setProperty(AIAPIConstants.TARGET_ENDPOINT, AIAPIConstants.REJECT_ENDPOINT);
}
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,36 @@ public static boolean checkForFileBasedApiContexts(String path, String tenantDom
.getGatewayArtifactSynchronizerProperties().getFileBasedApiContexts().contains(path);
}

/**
* Retrieves the appropriate failover policy configuration (Production/Sandbox).
* If no valid configuration is found, logs a debug message and returns null.
*
* @param messageContext The Synapse {@link MessageContext}.
* @param policyConfig The failover policy configuration DTO.
* @return The appropriate {@link FailoverPolicyDeploymentConfigDTO}, or null if invalid.
*/
public static FailoverPolicyDeploymentConfigDTO getTargetConfig(org.apache.synapse.MessageContext messageContext,
FailoverPolicyConfigDTO policyConfig) {

if (policyConfig == null) {
return null;
}

String apiKeyType = (String) messageContext.getProperty(APIConstants.API_KEY_TYPE);
FailoverPolicyDeploymentConfigDTO targetConfig = APIConstants.API_KEY_TYPE_PRODUCTION.equals(apiKeyType)
? policyConfig.getProduction()
: policyConfig.getSandbox();

if (targetConfig == null || targetConfig.getFallbackModelEndpoints() == null
|| targetConfig.getFallbackModelEndpoints().isEmpty()) {
if (log.isDebugEnabled()) {
log.debug("Failover policy is not set");
}
return null;
}
return targetConfig;
}

/**
* Retrieves available endpoints for the given policy configuration.
*
Expand Down Expand Up @@ -1810,5 +1840,4 @@ public static String getEndpointKey(ModelEndpointDTO endpoint) {
}
return endpoint.getEndpointId() + "_" + endpoint.getModel();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,11 @@ public static class AI {
public static final String MARKETPLACE_ASSISTANT_DELETE_API_RESOURCE = "ApiDeleteResource";
public static final String MARKETPLACE_ASSISTANT_API_COUNT_RESOURCE = "ApiCountResource";
public static final String AI_CONFIGURATION = "AiConfiguration";
public static final String AI_CONFIGURATION_FAILOVER_CONFIGURATIONS = "FailoverConfigurations";
public static final String AI_CONFIGURATION_ROUND_ROBIN_CONFIGURATIONS = "RoundRobinConfigurations";
public static final String AI_CONFIGURATION_FAILOVER_CONFIGURATIONS_FAILOVER_ENDPOINTS_LIMIT =
"FailoverEndpointsLimit";
public static final String AI_CONFIGURATION_DEFAULT_REQUEST_TIMEOUT = "DefaultRequestTimout";

public static final String DESIGN_ASSISTANT = "DesignAssistant";
public static final String DESIGN_ASSISTANT_ENABLED = "Enabled";
Expand Down
Loading

0 comments on commit ff9cb84

Please sign in to comment.