Skip to content

Commit

Permalink
Implement AI token generation using the key
Browse files Browse the repository at this point in the history
  • Loading branch information
piyumaldk committed Sep 9, 2024
1 parent c1793d5 commit f97c812
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3344,7 +3344,7 @@ public String getOpenAPIDefinition(String apiId, String organization) throws API
public String invokeApiChatExecute(String apiChatRequestId, String requestPayload) throws APIManagementException {
ApiChatConfigurationDTO configDto = ServiceReferenceHolder.getInstance().getAPIManagerConfigurationService()
.getAPIManagerConfiguration().getApiChatConfigurationDto();
return APIUtil.invokeAIService(configDto.getEndpoint(), configDto.getKey(),
return APIUtil.invokeAIService(configDto.getEndpoint(), configDto.getTokenEndpoint(), configDto.getKey(),
configDto.getExecuteResource(), requestPayload, apiChatRequestId);
}

Expand All @@ -3360,7 +3360,7 @@ public String invokeApiChatPrepare(String apiId, String apiChatRequestId, String

ApiChatConfigurationDTO configDto = ServiceReferenceHolder.getInstance().getAPIManagerConfigurationService()
.getAPIManagerConfiguration().getApiChatConfigurationDto();
return APIUtil.invokeAIService(configDto.getEndpoint(), configDto.getKey(),
return APIUtil.invokeAIService(configDto.getEndpoint(), configDto.getTokenEndpoint(), configDto.getKey(),
configDto.getPrepareResource(), payload.toString(), apiChatRequestId);
} catch (JsonProcessingException e) {
String error = "Error while parsing OpenAPI definition of API ID: " + apiId + " to JSON";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2418,6 +2418,11 @@ public void setMarketplaceAssistantConfiguration(OMElement omElement){
if (marketplaceAssistantEndpoint != null) {
marketplaceAssistantConfigurationDto.setEndpoint(marketplaceAssistantEndpoint.getText());
}
OMElement marketplaceAssistantTokenEndpoint =
omElement.getFirstChildWithName(new QName(APIConstants.AI.MARKETPLACE_ASSISTANT_TOKEN_ENDPOINT));
if (marketplaceAssistantTokenEndpoint != null) {
marketplaceAssistantConfigurationDto.setTokenEndpoint(marketplaceAssistantTokenEndpoint.getText());
}
OMElement marketplaceAssistantToken =
omElement.getFirstChildWithName(new QName(APIConstants.AI.MARKETPLACE_ASSISTANT_KEY));

Expand Down Expand Up @@ -2467,6 +2472,11 @@ public void setApiChatConfiguration(OMElement omElement){
if (apiChatEndpoint != null) {
apiChatConfigurationDto.setEndpoint(apiChatEndpoint.getText());
}
OMElement apiChatTokenEndpoint =
omElement.getFirstChildWithName(new QName(APIConstants.AI.API_CHAT_TOKEN_ENDPOINT));
if (apiChatEndpoint != null) {
apiChatConfigurationDto.setTokenEndpoint(apiChatTokenEndpoint.getText());
}
OMElement apiChatKey =
omElement.getFirstChildWithName(new QName(APIConstants.AI.API_CHAT_KEY));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ public void run() {
payload.put(APIConstants.VERSION, apiEvent.getApiVersion());

APIUtil.invokeAIService(marketplaceAssistantConfigurationDto.getEndpoint(),
marketplaceAssistantConfigurationDto.getTokenEndpoint(),
marketplaceAssistantConfigurationDto.getKey(),
marketplaceAssistantConfigurationDto.getApiPublishResource(), payload.toString(), null);
} catch (APIManagementException e) {
Expand All @@ -226,6 +227,7 @@ public MarketplaceAssistantDeletionTask(String uuid) {
public void run() {
try {
APIUtil.marketplaceAssistantDeleteService(marketplaceAssistantConfigurationDto.getEndpoint(),
marketplaceAssistantConfigurationDto.getTokenEndpoint(),
marketplaceAssistantConfigurationDto.getKey(),
marketplaceAssistantConfigurationDto.getApiDeleteResource(), uuid);
} catch (APIManagementException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
import org.apache.commons.logging.LogFactory;
import org.apache.commons.validator.routines.RegexValidator;
import org.apache.commons.validator.routines.UrlValidator;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpResponse;
Expand Down Expand Up @@ -252,6 +254,7 @@
import java.net.URLDecoder;
import java.net.UnknownHostException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.rmi.RemoteException;
import java.security.*;
import java.security.cert.Certificate;
Expand Down Expand Up @@ -10410,22 +10413,69 @@ public static String getScopesAsString(Set<Scope> scopes) {
return scopesStringBuilder.toString().trim();
}

/**
* Get the AI token using the key
*/
private static String getAiTokenUsingKey(String tokenEndpoint, String key) throws APIManagementException {
try {
HttpPost tokenRequest = new HttpPost(tokenEndpoint);
tokenRequest.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + key);
tokenRequest.setHeader(HttpHeaders.CONTENT_TYPE, APIConstants.CONTENT_TYPE_APPLICATION_FORM);

List<BasicNameValuePair> urlParameters = new ArrayList<>();
urlParameters.add(new BasicNameValuePair(APIConstants.TOKEN_GRANT_TYPE_KEY, APIConstants.GRANT_TYPE_VALUE));
tokenRequest.setEntity(new UrlEncodedFormEntity(urlParameters, StandardCharsets.UTF_8));

URL url = new URL(tokenEndpoint);
int port = url.getPort();
String protocol = url.getProtocol();
HttpClient httpClient = APIUtil.getHttpClient(port, protocol);

CloseableHttpResponse response = executeHTTPRequest(tokenRequest, httpClient);
int statusCode = response.getStatusLine().getStatusCode();
String responseStr = EntityUtils.toString(response.getEntity());

if (statusCode == HttpStatus.SC_OK) {
org.json.JSONObject responseJson = new org.json.JSONObject(responseStr);
if (responseJson.has("access_token")) {
log.info("AI access token : " + responseJson.getString("access_token"));
return responseJson.getString("access_token");
} else {
throw new APIManagementException("Access token not found in the response.",
ExceptionCodes.AI_SERVICE_INVALID_RESPONSE);
}
} else if (statusCode == HttpStatus.SC_UNAUTHORIZED) {
throw new APIManagementException("Invalid key used when requesting AI token.",
ExceptionCodes.AI_SERVICE_INVALID_RESPONSE);
} else {
throw new APIManagementException("Unexpected response when fetching AI token",
ExceptionCodes.AI_SERVICE_INVALID_RESPONSE);
}
} catch (MalformedURLException e) {
throw new APIManagementException("Invalid/malformed URL encountered. URL: " + tokenEndpoint, e);
} catch (IOException e) {
throw new APIManagementException("Error encountered while connecting to token service", e);
}
}

/**
* This method is used to invoke the Choreo deployed AI service. This can handle both API Chat and Marketplace
* Assistant related POST calls.
*
* @param endpoint Endpoint to be invoked
* @param authToken OnPremKey for the organization
* @param resource Specifies the backend resource the request should be forwarded to
* @param payload Request payload that needs to be attached to the request
* @param requestId UUID of the request, so that AI service can track the progress
* @param endpoint Endpoint to be invoked
* @param tokenEndpoint Endpoint to be invoked to get the token using key
* @param authToken OnPremKey for the organization
* @param resource Specifies the backend resource the request should be forwarded to
* @param payload Request payload that needs to be attached to the request
* @param requestId UUID of the request, so that AI service can track the progress
* @return returns the response if invocation is successful
* @throws APIManagementException if an error occurs while invoking the AI service
*/
public static String invokeAIService(String endpoint, String authToken, String resource, String payload,
String requestId) throws APIManagementException {
public static String invokeAIService(String endpoint, String tokenEndpoint, String authToken, String resource,
String payload, String requestId) throws APIManagementException {

try {
String token = getAiTokenUsingKey(tokenEndpoint, authToken);
HttpPost preparePost = new HttpPost(endpoint + resource);
preparePost.setHeader(APIConstants.API_KEY_AUTH, authToken);
preparePost.setHeader(HttpHeaders.CONTENT_TYPE, APIConstants.APPLICATION_JSON_MEDIA_TYPE);
Expand Down Expand Up @@ -10481,10 +10531,11 @@ public static String invokeAIService(String endpoint, String authToken, String r
* @return CloseableHttpResponse of the GET call
* @throws APIManagementException if an error occurs while retrieving API count
*/
public static CloseableHttpResponse getMarketplaceChatApiCount(String endpoint, String authToken, String resource)
throws APIManagementException {
public static CloseableHttpResponse getMarketplaceChatApiCount(String endpoint, String tokenEndpoint,
String authToken, String resource) throws APIManagementException {

try {
String token = getAiTokenUsingKey(tokenEndpoint, authToken);
HttpGet apiCountGet = new HttpGet(endpoint + resource);
apiCountGet.setHeader(APIConstants.API_KEY_AUTH, authToken);
URL url = new URL(endpoint);
Expand All @@ -10508,10 +10559,11 @@ public static CloseableHttpResponse getMarketplaceChatApiCount(String endpoint,
* @param uuid UUID of the API to be deleted
* @throws APIManagementException if an error occurs while deleting the API
*/
public static void marketplaceAssistantDeleteService(String endpoint, String authToken, String resource, String uuid)
throws APIManagementException {
public static void marketplaceAssistantDeleteService(String endpoint, String tokenEndpoint, String authToken,
String resource, String uuid) throws APIManagementException {

try {
String token = getAiTokenUsingKey(tokenEndpoint, authToken);
String resourceWithPathParam = endpoint + resource + "/{uuid}";
resourceWithPathParam = resourceWithPathParam.replace("{uuid}", uuid);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ public Response marketplaceAssistantExecute(MarketplaceAssistantRequestDTO marke
payload.put(APIConstants.HISTORY, history);
payload.put(APIConstants.TENANT_DOMAIN, organization);

String response = APIUtil.invokeAIService(configDto.getEndpoint(), configDto.getKey(),
configDto.getChatResource(), payload.toString(), null);
String response = APIUtil.invokeAIService(configDto.getEndpoint(), configDto.getTokenEndpoint(),
configDto.getKey(), configDto.getChatResource(), payload.toString(), null);

ObjectMapper objectMapper = new ObjectMapper();
MarketplaceAssistantResponseDTO executeResponseDTO = objectMapper.readValue(response,
Expand Down Expand Up @@ -118,6 +118,7 @@ public Response getMarketplaceAssistantApiCount(MessageContext messageContext) t

CloseableHttpResponse response = APIUtil.
getMarketplaceChatApiCount(configDto.getEndpoint(),
configDto.getTokenEndpoint(),
configDto.getKey(),
configDto.getApiCountResource());
int statusCode = response.getStatusLine().getStatusCode();
Expand Down

0 comments on commit f97c812

Please sign in to comment.