Skip to content

Commit

Permalink
Merge pull request #1193 from Tharsanan1/airl
Browse files Browse the repository at this point in the history
Add backend based airls from cp to dp
  • Loading branch information
CrowleyRajapakse authored Sep 26, 2024
2 parents 1f04e23 + 0702d6a commit 25e2be9
Show file tree
Hide file tree
Showing 12 changed files with 285 additions and 51 deletions.
2 changes: 1 addition & 1 deletion apim-apk-agent/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ require (
github.com/pelletier/go-toml v1.9.5
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.9.0
github.com/wso2/apk/common-go-libs v0.0.0-20240920041902-85449a1c0150
github.com/wso2/apk/common-go-libs v0.0.0-20240923143402-ff7fdb0366f9
google.golang.org/grpc v1.62.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v2 v2.4.0
Expand Down
4 changes: 2 additions & 2 deletions apim-apk-agent/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ github.com/vektah/gqlparser v1.3.1 h1:8b0IcD3qZKWJQHSzynbDlrtP3IxVydZ2DZepCGofqf
github.com/vektah/gqlparser v1.3.1/go.mod h1:bkVf0FX+Stjg/MHnm8mEyubuaArhNEqfQhF+OTiAL74=
github.com/wso2/apk/adapter v0.0.0-20240408123538-86a74d977eee h1:g0ivVkzybfcEkB0vBGTAXTUuMZpsF3zOTVtAgmW851s=
github.com/wso2/apk/adapter v0.0.0-20240408123538-86a74d977eee/go.mod h1:xYS5auF/YxnyRykw7NBSn/YR2FHD4hTeyav4Nhec8d0=
github.com/wso2/apk/common-go-libs v0.0.0-20240920041902-85449a1c0150 h1:X3OezAh2UOxmQIRxsAua87nNqmoIGXx1yfQIvc4a+G4=
github.com/wso2/apk/common-go-libs v0.0.0-20240920041902-85449a1c0150/go.mod h1:SbZVA1jeiVG9dqk9fGcY/bB0JgEaQgtXqFAlxAfN0Lk=
github.com/wso2/apk/common-go-libs v0.0.0-20240923143402-ff7fdb0366f9 h1:MwQqG+/ODDIfLfc3xNMYk6jM+hB2ttjwZnaDBeiMOJI=
github.com/wso2/apk/common-go-libs v0.0.0-20240923143402-ff7fdb0366f9/go.mod h1:SbZVA1jeiVG9dqk9fGcY/bB0JgEaQgtXqFAlxAfN0Lk=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
Expand Down
4 changes: 2 additions & 2 deletions apim-apk-agent/internal/eventhub/dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"strconv"
"time"

dpv1alpha2 "github.com/wso2/apk/common-go-libs/apis/dp/v1alpha2"
dpv1alpha3 "github.com/wso2/apk/common-go-libs/apis/dp/v1alpha3"
"github.com/wso2/product-apim-tooling/apim-apk-agent/config"
internalk8sClient "github.com/wso2/product-apim-tooling/apim-apk-agent/internal/k8sClient"
logger "github.com/wso2/product-apim-tooling/apim-apk-agent/internal/loggers"
Expand Down Expand Up @@ -240,7 +240,7 @@ func FetchAPIsOnStartUp(conf *config.Config, k8sClient client.Client) {
if err != nil {
logger.LoggerEventhub.Errorf("Error occurred while fetching APIs from control plane %v", err)
}
removeApis := make([]dpv1alpha2.API, 0)
removeApis := make([]dpv1alpha3.API, 0)
for _, k8sAPI := range k8sAPIS {
found := false
if apis != nil {
Expand Down
74 changes: 67 additions & 7 deletions apim-apk-agent/internal/k8sClient/k8s_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ import (
)

// DeployAPICR applies the given API struct to the Kubernetes cluster.
func DeployAPICR(api *dpv1alpha2.API, k8sClient client.Client) {
crAPI := &dpv1alpha2.API{}
func DeployAPICR(api *dpv1alpha3.API, k8sClient client.Client) {
crAPI := &dpv1alpha3.API{}
if err := k8sClient.Get(context.Background(), client.ObjectKey{Namespace: api.ObjectMeta.Namespace, Name: api.Name}, crAPI); err != nil {
if !k8error.IsNotFound(err) {
loggers.LoggerK8sClient.Error("Unable to get API CR: " + err.Error())
Expand All @@ -66,7 +66,7 @@ func DeployAPICR(api *dpv1alpha2.API, k8sClient client.Client) {
}

// UndeployK8sAPICR removes the API Custom Resource from the Kubernetes cluster based on API ID label.
func UndeployK8sAPICR(k8sClient client.Client, k8sAPI dpv1alpha2.API) error {
func UndeployK8sAPICR(k8sClient client.Client, k8sAPI dpv1alpha3.API) error {
err := k8sClient.Delete(context.Background(), &k8sAPI, &client.DeleteOptions{})
if err != nil {
loggers.LoggerK8sClient.Errorf("Unable to delete API CR: %v", err)
Expand All @@ -82,7 +82,7 @@ func UndeployAPICR(apiID string, k8sClient client.Client) {
if errReadConfig != nil {
loggers.LoggerK8sClient.Errorf("Error reading configurations: %v", errReadConfig)
}
apiList := &dpv1alpha2.APIList{}
apiList := &dpv1alpha3.APIList{}
err := k8sClient.List(context.Background(), apiList, &client.ListOptions{Namespace: conf.DataPlane.Namespace, LabelSelector: labels.SelectorFromSet(map[string]string{"apiUUID": apiID})})
// Retrieve all API CRs from the Kubernetes cluster
if err != nil {
Expand Down Expand Up @@ -429,6 +429,66 @@ func DeploySubscriptionRateLimitPolicyCR(policy eventhubTypes.SubscriptionPolicy

}

// DeployAIRateLimitPolicyCR applies the given AIRateLimitPolicies struct to the Kubernetes cluster.
func DeployAIRateLimitPolicyCR(policy eventhubTypes.SubscriptionPolicy, k8sClient client.Client) {
conf, _ := config.ReadConfigs()
tokenCount := &dpv1alpha3.TokenCount{}
requestCount := &dpv1alpha3.RequestCount{}
if policy.DefaultLimit.AiAPIQuota.PromptTokenCount != nil &&
policy.DefaultLimit.AiAPIQuota.CompletionTokenCount != nil &&
policy.DefaultLimit.AiAPIQuota.TotalTokenCount != nil {
tokenCount = &dpv1alpha3.TokenCount{
Unit: policy.DefaultLimit.AiAPIQuota.TimeUnit,
RequestTokenCount: uint32(*policy.DefaultLimit.AiAPIQuota.PromptTokenCount),
ResponseTokenCount: uint32(*policy.DefaultLimit.AiAPIQuota.CompletionTokenCount),
TotalTokenCount: uint32(*policy.DefaultLimit.AiAPIQuota.TotalTokenCount),
}
} else {
tokenCount = nil
}
if policy.DefaultLimit.AiAPIQuota.RequestCount != nil {
requestCount = &dpv1alpha3.RequestCount{
RequestsPerUnit: uint32(*policy.DefaultLimit.AiAPIQuota.RequestCount),
Unit: policy.DefaultLimit.AiAPIQuota.TimeUnit,
}
} else {
requestCount = nil
}

crRateLimitPolicies := dpv1alpha3.AIRateLimitPolicy{
ObjectMeta: metav1.ObjectMeta{Name: policy.Name,
Namespace: conf.DataPlane.Namespace,
},
Spec: dpv1alpha3.AIRateLimitPolicySpec{
Override: &dpv1alpha3.AIRateLimit{
Organization: policy.TenantDomain,
TokenCount: tokenCount,
RequestCount: requestCount,
},
TargetRef: gwapiv1b1.PolicyTargetReference{Group: constants.GatewayGroup, Kind: "Subscription", Name: "default"},
},
}
crRateLimitPolicyFetched := &dpv1alpha3.AIRateLimitPolicy{}
if err := k8sClient.Get(context.Background(), client.ObjectKey{Namespace: crRateLimitPolicies.ObjectMeta.Namespace, Name: crRateLimitPolicies.Name}, crRateLimitPolicyFetched); err != nil {
if !k8error.IsNotFound(err) {
loggers.LoggerK8sClient.Error("Unable to get AiratelimitPolicy CR: " + err.Error())
}
if err := k8sClient.Create(context.Background(), &crRateLimitPolicies); err != nil {
loggers.LoggerK8sClient.Error("Unable to create AIRateLimitPolicies CR: " + err.Error())
} else {
loggers.LoggerK8sClient.Info("AIRateLimitPolicies CR created: " + crRateLimitPolicies.Name)
}
} else {
crRateLimitPolicyFetched.Spec = crRateLimitPolicies.Spec
crRateLimitPolicyFetched.ObjectMeta.Labels = crRateLimitPolicies.ObjectMeta.Labels
if err := k8sClient.Update(context.Background(), crRateLimitPolicyFetched); err != nil {
loggers.LoggerK8sClient.Error("Unable to update AiRatelimitPolicy CR: " + err.Error())
} else {
loggers.LoggerK8sClient.Info("AiRatelimitPolicy CR updated: " + crRateLimitPolicyFetched.Name)
}
}
}

// DeployBackendCR applies the given Backends struct to the Kubernetes cluster.
func DeployBackendCR(backends *dpv1alpha2.Backend, k8sClient client.Client) {
crBackends := &dpv1alpha2.Backend{}
Expand Down Expand Up @@ -625,10 +685,10 @@ func getSha1Value(input string) string {
}

// RetrieveAllAPISFromK8s retrieves all the API CRs from the Kubernetes cluster
func RetrieveAllAPISFromK8s(k8sClient client.Client, nextToken string) ([]dpv1alpha2.API, string, error) {
func RetrieveAllAPISFromK8s(k8sClient client.Client, nextToken string) ([]dpv1alpha3.API, string, error) {
conf, _ := config.ReadConfigs()
apiList := dpv1alpha2.APIList{}
resolvedAPIList := make([]dpv1alpha2.API, 0)
apiList := dpv1alpha3.APIList{}
resolvedAPIList := make([]dpv1alpha3.API, 0)
var err error
if nextToken == "" {
err = k8sClient.List(context.Background(), &apiList, &client.ListOptions{Namespace: conf.DataPlane.Namespace})
Expand Down
87 changes: 59 additions & 28 deletions apim-apk-agent/internal/synchronizer/ratelimit_policy_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func FetchRateLimitPoliciesOnEvent(ratelimitName string, organization string, c

// FetchSubscriptionRateLimitPoliciesOnEvent fetches the policies from the control plane on the start up and notification event updates
func FetchSubscriptionRateLimitPoliciesOnEvent(ratelimitName string, organization string, c client.Client) {
logger.LoggerSynchronizer.Info("Fetching RateLimit Policies from Control Plane.")
logger.LoggerSynchronizer.Info("Fetching Subscription RateLimit Policies from Control Plane.")

// Read configurations and derive the eventHub details
conf, errReadConfig := config.ReadConfigs()
Expand All @@ -189,7 +189,7 @@ func FetchSubscriptionRateLimitPoliciesOnEvent(ratelimitName string, organizatio
}
}

logger.LoggerSynchronizer.Infof("Fetching RateLimit Policies from the URL %v: ", ehURL)
logger.LoggerSynchronizer.Infof("Fetching Subscription RateLimit Policies from the URL %v: ", ehURL)

ehUname := ehConfigs.Username
ehPass := ehConfigs.Password
Expand All @@ -201,19 +201,9 @@ func FetchSubscriptionRateLimitPoliciesOnEvent(ratelimitName string, organizatio
// Create a HTTP request
req, err := http.NewRequest("GET", ehURL, nil)
if err != nil {
logger.LoggerSynchronizer.Errorf("Error while creating http request for RateLimit Policies Endpoint : %v", err)
logger.LoggerSynchronizer.Errorf("Error while creating http request for Subscription RateLimit Policies Endpoint : %v", err)
}

var queryParamMap map[string]string

if queryParamMap != nil && len(queryParamMap) > 0 {
q := req.URL.Query()
// Making necessary query parameters for the request
for queryParamKey, queryParamValue := range queryParamMap {
q.Add(queryParamKey, queryParamValue)
}
req.URL.RawQuery = q.Encode()
}
// Setting authorization header
req.Header.Set(sync.Authorization, basicAuth)

Expand All @@ -231,45 +221,74 @@ func FetchSubscriptionRateLimitPoliciesOnEvent(ratelimitName string, organizatio
var errorMsg string
if err != nil {
errorMsg = "Error occurred while calling the REST API: " + policiesEndpoint
go retryRLPFetchData(conf, errorMsg, err, c)
go retrySubscriptionRLPFetchData(conf, errorMsg, err, c)
return
}
responseBytes, err := ioutil.ReadAll(resp.Body)
logger.LoggerSynchronizer.Debugf("Response String received for Policies: %v", string(responseBytes))

if err != nil {
errorMsg = "Error occurred while reading the response received for: " + policiesEndpoint
go retryRLPFetchData(conf, errorMsg, err, c)
go retrySubscriptionRLPFetchData(conf, errorMsg, err, c)
return
}

if resp.StatusCode == http.StatusOK {
var rateLimitPolicyList eventhubTypes.SubscriptionPolicyList
err := json.Unmarshal(responseBytes, &rateLimitPolicyList)
if err != nil {
logger.LoggerSynchronizer.Errorf("Error occurred while unmarshelling RateLimit Policies event data %v", err)
logger.LoggerSynchronizer.Errorf("Error occurred while unmarshelling Subscription RateLimit Policies event data %v", err)
return
}
logger.LoggerSynchronizer.Debugf("Policies received: %v", rateLimitPolicyList.List)
var rateLimitPolicies []eventhubTypes.SubscriptionPolicy = rateLimitPolicyList.List
for _, policy := range rateLimitPolicies {
if policy.DefaultLimit.RequestCount.TimeUnit == "min" {
policy.DefaultLimit.RequestCount.TimeUnit = "Minute"
} else if policy.DefaultLimit.RequestCount.TimeUnit == "hours" {
policy.DefaultLimit.RequestCount.TimeUnit = "Hour"
} else if policy.DefaultLimit.RequestCount.TimeUnit == "days" {
policy.DefaultLimit.RequestCount.TimeUnit = "Day"
if policy.QuotaType == "aiApiQuota" {
if policy.DefaultLimit.AiAPIQuota != nil {
switch policy.DefaultLimit.AiAPIQuota.TimeUnit {
case "min":
policy.DefaultLimit.AiAPIQuota.TimeUnit = "Minute"
case "hours":
policy.DefaultLimit.AiAPIQuota.TimeUnit = "Hour"
case "days":
policy.DefaultLimit.AiAPIQuota.TimeUnit = "Day"
default:
logger.LoggerSynchronizer.Errorf("Unsupported timeunit %s", policy.DefaultLimit.AiAPIQuota.TimeUnit)
continue
}
if policy.DefaultLimit.AiAPIQuota.PromptTokenCount == nil && policy.DefaultLimit.AiAPIQuota.TotalTokenCount != nil {
policy.DefaultLimit.AiAPIQuota.PromptTokenCount = policy.DefaultLimit.AiAPIQuota.TotalTokenCount
}
if policy.DefaultLimit.AiAPIQuota.CompletionTokenCount == nil && policy.DefaultLimit.AiAPIQuota.TotalTokenCount != nil {
policy.DefaultLimit.AiAPIQuota.CompletionTokenCount = policy.DefaultLimit.AiAPIQuota.TotalTokenCount
}
if policy.DefaultLimit.AiAPIQuota.TotalTokenCount == nil && policy.DefaultLimit.AiAPIQuota.PromptTokenCount != nil && policy.DefaultLimit.AiAPIQuota.CompletionTokenCount != nil {
total := *policy.DefaultLimit.AiAPIQuota.PromptTokenCount + *policy.DefaultLimit.AiAPIQuota.CompletionTokenCount
policy.DefaultLimit.AiAPIQuota.TotalTokenCount = &total
}
managementserver.AddSubscriptionPolicy(policy)
k8sclient.DeployAIRateLimitPolicyCR(policy, c)
} else {
logger.LoggerSynchronizer.Errorf("AIQuota type response recieved but no data found. %+v", policy.DefaultLimit)
}
} else {
if policy.DefaultLimit.RequestCount.TimeUnit == "min" {
policy.DefaultLimit.RequestCount.TimeUnit = "Minute"
} else if policy.DefaultLimit.RequestCount.TimeUnit == "hours" {
policy.DefaultLimit.RequestCount.TimeUnit = "Hour"
} else if policy.DefaultLimit.RequestCount.TimeUnit == "days" {
policy.DefaultLimit.RequestCount.TimeUnit = "Day"
}
managementserver.AddSubscriptionPolicy(policy)
logger.LoggerSynchronizer.Infof("RateLimit Policy added to internal map: %v", policy)
// Update the exisitng rate limit policies with current policy
k8sclient.DeploySubscriptionRateLimitPolicyCR(policy, c)
}
managementserver.AddSubscriptionPolicy(policy)
logger.LoggerSynchronizer.Infof("RateLimit Policy added to internal map: %v", policy)
// Update the exisitng rate limit policies with current policy
k8sclient.DeploySubscriptionRateLimitPolicyCR(policy, c)

}
} else {
errorMsg = "Failed to fetch data! " + policiesEndpoint + " responded with " +
strconv.Itoa(resp.StatusCode)
go retryRLPFetchData(conf, errorMsg, err, c)
go retrySubscriptionRLPFetchData(conf, errorMsg, err, c)
}
}

Expand All @@ -284,3 +303,15 @@ func retryRLPFetchData(conf *config.Config, errorMessage string, err error, c cl
return
}
}

func retrySubscriptionRLPFetchData(conf *config.Config, errorMessage string, err error, c client.Client) {
logger.LoggerSynchronizer.Debugf("Time Duration for retrying: %v",
conf.ControlPlane.RetryInterval*time.Second)
time.Sleep(conf.ControlPlane.RetryInterval * time.Second)
FetchSubscriptionRateLimitPoliciesOnEvent("", "", c)
retryAttempt++
if retryAttempt >= retryCount {
logger.LoggerSynchronizer.Errorf(errorMessage, err)
return
}
}
13 changes: 12 additions & 1 deletion apim-apk-agent/pkg/eventhub/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ type ConditionGroup struct {

// DefaultLimit represents the default limit within the response.
type DefaultLimit struct {
QuotaType string `json:"quotaType"`
AiAPIQuota *AiAPIQuota `json:"aiApiQuota"`
QuotaType string `json:"quotaType"`
RequestCount struct {
TimeUnit string `json:"timeUnit"`
UnitTime int `json:"unitTime"`
Expand All @@ -206,6 +207,16 @@ type DefaultLimit struct {
EventCount interface{} `json:"eventCount"`
}

// AiAPIQuota contains the AI ratelimit configurations
type AiAPIQuota struct {
CompletionTokenCount *int `json:"completionTokenCount"`
PromptTokenCount *int `json:"promptTokenCount"`
RequestCount *int `json:"requestCount"`
TimeUnit string `json:"timeUnit"`
TotalTokenCount *int `json:"totalTokenCount"`
UnitTime int `json:"unitTime"`
}

// Scope for struct Scope
type Scope struct {
Name string `json:"name"`
Expand Down
2 changes: 1 addition & 1 deletion apim-apk-agent/pkg/synchronizer/apis_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func SendRequestToControlPlane(req *http.Request, apiID *string, gwLabels []stri
if apiID != nil {
logger.LoggerSync.Debugf("Sending the control plane request for the API: %q", *apiID)
} else {
logger.LoggerSync.Debug("Sending the control plane request")
logger.LoggerSync.Debugf("Sending the control plane request, url: %s", req.URL.String())
}
resp, err := client.Do(req)

Expand Down
41 changes: 41 additions & 0 deletions apim-apk-agent/pkg/transformer/api_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ type APIMApi struct {
APIThrottlingPolicy string `yaml:"apiThrottlingPolicy"`
APIPolicies APIMOperationPolicies `yaml:"apiPolicies"`
AIConfiguration APIMAIConfiguration `yaml:"aiConfiguration"`
MaxTps *MaxTps `yaml:"maxTps"`
}

// APIMAIConfiguration holds the configuration details for AI providers
Expand All @@ -192,6 +193,46 @@ type APIYaml struct {
Data APIMApi `json:"data"`
}

// MaxTps represents the maximum transactions per second (TPS) settings for both
// production and sandbox environments. It also includes an optional configuration
// for token-based throttling.
//
// Fields:
// - Production: Maximum TPS for the production environment.
// - ProductionTimeUnit: The time unit for the production TPS limit (e.g., seconds, minutes).
// - Sandbox: Maximum TPS for the sandbox environment.
// - SandboxTimeUnit: The time unit for the sandbox TPS limit.
// - TokenBasedThrottlingConfiguration: Configuration for token-based throttling.
type MaxTps struct {
Production *int `yaml:"production"`
ProductionTimeUnit *string `yaml:"productionTimeUnit"`
Sandbox *int `yaml:"sandbox"`
SandboxTimeUnit *string `yaml:"sandboxTimeUnit"`
TokenBasedThrottlingConfiguration *TokenBasedThrottlingConfig `yaml:"tokenBasedThrottlingConfiguration"`
}

// TokenBasedThrottlingConfig defines the token-based throttling limits for
// both production and sandbox environments. Token-based throttling places
// a limit on the number of prompt and completion tokens that can be used.
//
// Fields:
// - ProductionMaxPromptTokenCount: Maximum number of prompt tokens for production.
// - ProductionMaxCompletionTokenCount: Maximum number of completion tokens for production.
// - ProductionMaxTotalTokenCount: Maximum total token count (prompt + completion) for production.
// - SandboxMaxPromptTokenCount: Maximum number of prompt tokens for sandbox.
// - SandboxMaxCompletionTokenCount: Maximum number of completion tokens for sandbox.
// - SandboxMaxTotalTokenCount: Maximum total token count (prompt + completion) for sandbox.
// - IsTokenBasedThrottlingEnabled: Flag to enable or disable token-based throttling.
type TokenBasedThrottlingConfig struct {
ProductionMaxPromptTokenCount *int `yaml:"productionMaxPromptTokenCount"`
ProductionMaxCompletionTokenCount *int `yaml:"productionMaxCompletionTokenCount"`
ProductionMaxTotalTokenCount *int `yaml:"productionMaxTotalTokenCount"`
SandboxMaxPromptTokenCount *int `yaml:"sandboxMaxPromptTokenCount"`
SandboxMaxCompletionTokenCount *int `yaml:"sandboxMaxCompletionTokenCount"`
SandboxMaxTotalTokenCount *int `yaml:"sandboxMaxTotalTokenCount"`
IsTokenBasedThrottlingEnabled *bool `yaml:"isTokenBasedThrottlingEnabled"`
}

// APIArtifact represents the artifact details of an API, including api details, environment configuration,
// Swagger definition, deployment descriptor, and revision ID extracted from the API Project Zip.
type APIArtifact struct {
Expand Down
Loading

0 comments on commit 25e2be9

Please sign in to comment.