Skip to content

Commit

Permalink
AAP-32019: Enable Playbook generation/explanation endpoint in Lightspeed service for on-prem
Browse files Browse the repository at this point in the history
  • Loading branch information
manstis committed Nov 1, 2024
1 parent 6b61b6d commit 33f1e7d
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 161 deletions.
129 changes: 128 additions & 1 deletion ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import logging
import sys
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, cast

import backoff
import requests
from django.apps import apps
from django.conf import settings
from django_prometheus.conf import NAMESPACE
from health_check.exceptions import ServiceUnavailable
Expand Down Expand Up @@ -48,6 +49,10 @@
ModelPipelineContentMatch,
ModelPipelinePlaybookExplanation,
ModelPipelinePlaybookGeneration,
PlaybookExplanationParameters,
PlaybookExplanationResponse,
PlaybookGenerationParameters,
PlaybookGenerationResponse,
)
from ansible_ai_connect.ai.api.model_pipelines.wca.wca_utils import (
ContentMatchResponseChecks,
Expand Down Expand Up @@ -389,10 +394,132 @@ class WCABasePlaybookGenerationPipeline(
def __init__(self, inference_url):
    """Initialise the pipeline by forwarding ``inference_url`` to the parent initialiser."""
    super().__init__(inference_url=inference_url)

def invoke(self, params: PlaybookGenerationParameters) -> PlaybookGenerationResponse:
    """Generate a playbook through the WCA playbook code-generation endpoint.

    The API key and model id are resolved for the requesting user (and their
    organization, when they have one). The prompt is POSTed to
    ``/v1/wca/codegen/ansible/playbook``; the call is retried with
    exponential backoff for up to ``self.retries + 1`` attempts unless
    ``self.fatal_exception`` decides to give up. The response is then
    validated and, when an ansible-lint caller is configured on the ``ai``
    app, the generated playbook is passed through the linter.

    Args:
        params: carries ``request``, ``text``, ``custom_prompt``,
            ``create_outline``, ``outline``, ``model_id`` and
            ``generation_id``.

    Returns:
        A ``(playbook, outline, warnings)`` tuple. ``warnings`` is an empty
        list when the response has no ``"warnings"`` key.

    Raises:
        WcaRequestIdCorrelationFailure: the request id echoed in the response
            headers differs from ``params.generation_id``.
        Other exceptions may come from ``InferenceResponseChecks`` and from
        ``result.raise_for_status()``.
    """
    user = params.request.user
    generation_id = params.generation_id

    organization_id = user.organization.id if user.organization else None
    api_key = self.get_api_key(user, organization_id)
    model_id = self.get_model_id(user, organization_id, params.model_id)
    headers = self.get_request_headers(api_key, generation_id)

    payload = {
        "model_id": model_id,
        "text": params.text,
        "create_outline": params.create_outline,
    }
    # Optional fields are only sent when they carry a value.
    if params.outline:
        payload["outline"] = params.outline
    custom_prompt = params.custom_prompt
    if custom_prompt:
        # The custom prompt is always sent newline-terminated.
        payload["custom_prompt"] = (
            custom_prompt if custom_prompt.endswith("\n") else f"{custom_prompt}\n"
        )

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=self.retries + 1,
        giveup=self.fatal_exception,
        on_backoff=self.on_backoff_codegen_playbook,
    )
    @wca_codegen_playbook_hist.time()
    def post_request():
        return self.session.post(
            f"{self._inference_url}/v1/wca/codegen/ansible/playbook",
            headers=headers,
            json=payload,
            verify=settings.ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL,
        )

    result = post_request()

    # The id sent with the request is a UUID whereas HTTP header values are
    # strings, so the comparison is done on the string form.
    x_request_id = result.headers.get(WCA_REQUEST_ID_HEADER)
    if generation_id and x_request_id and x_request_id != str(generation_id):
        raise WcaRequestIdCorrelationFailure(model_id=model_id, x_request_id=x_request_id)

    check_context = Context(model_id, result, False)
    InferenceResponseChecks().run_checks(check_context)
    result.raise_for_status()

    body = json.loads(result.text)
    playbook = body["playbook"]
    outline = body["outline"]
    warnings = body.get("warnings", [])

    # NOTE(review): imported at call time rather than module level,
    # presumably to avoid an import cycle with the apps module; confirm.
    from ansible_ai_connect.ai.apps import AiConfig

    ai_config = cast(AiConfig, apps.get_app_config("ai"))
    lint_caller = ai_config.get_ansible_lint_caller()
    if lint_caller:
        playbook = lint_caller.run_linter(playbook)

    return playbook, outline, warnings


class WCABasePlaybookExplanationPipeline(
    WCABasePipeline, ModelPipelinePlaybookExplanation, metaclass=ABCMeta
):
    """Base WCA pipeline that asks the playbook explanation endpoint to explain a playbook."""

    def __init__(self, inference_url):
        """Initialise the pipeline by forwarding ``inference_url`` to the parent initialiser."""
        super().__init__(inference_url=inference_url)

    def invoke(self, params: PlaybookExplanationParameters) -> PlaybookExplanationResponse:
        """Request an explanation of ``params.content`` from WCA.

        The API key and model id are resolved for the requesting user (and
        their organization, when they have one). The playbook is POSTed to
        ``/v1/wca/explain/ansible/playbook``; the call is retried with
        exponential backoff for up to ``self.retries + 1`` attempts unless
        ``self.fatal_exception`` decides to give up.

        Args:
            params: carries ``request``, ``content``, ``custom_prompt``,
                ``model_id`` and ``explanation_id``.

        Returns:
            The value stored under ``"explanation"`` in the JSON response.

        Raises:
            WcaRequestIdCorrelationFailure: the request id echoed in the
                response headers differs from ``params.explanation_id``.
            Other exceptions may come from ``InferenceResponseChecks`` and
            from ``result.raise_for_status()``.
        """
        user = params.request.user
        explanation_id = params.explanation_id

        organization_id = user.organization.id if user.organization else None
        api_key = self.get_api_key(user, organization_id)
        model_id = self.get_model_id(user, organization_id, params.model_id)
        headers = self.get_request_headers(api_key, explanation_id)

        payload = {
            "model_id": model_id,
            "playbook": params.content,
        }
        custom_prompt = params.custom_prompt
        if custom_prompt:
            # The custom prompt is always sent newline-terminated.
            payload["custom_prompt"] = (
                custom_prompt if custom_prompt.endswith("\n") else f"{custom_prompt}\n"
            )

        @backoff.on_exception(
            backoff.expo,
            Exception,
            max_tries=self.retries + 1,
            giveup=self.fatal_exception,
            on_backoff=self.on_backoff_explain_playbook,
        )
        @wca_explain_playbook_hist.time()
        def post_request():
            return self.session.post(
                f"{self._inference_url}/v1/wca/explain/ansible/playbook",
                headers=headers,
                json=payload,
                verify=settings.ANSIBLE_AI_MODEL_MESH_API_VERIFY_SSL,
            )

        result = post_request()

        # The id sent with the request is a UUID whereas HTTP header values
        # are strings, so the comparison is done on the string form.
        x_request_id = result.headers.get(WCA_REQUEST_ID_HEADER)
        if explanation_id and x_request_id and x_request_id != str(explanation_id):
            raise WcaRequestIdCorrelationFailure(model_id=model_id, x_request_id=x_request_id)

        check_context = Context(model_id, result, False)
        InferenceResponseChecks().run_checks(check_context)
        result.raise_for_status()

        body = json.loads(result.text)
        return body["explanation"]
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def __init__(self, inference_url):
super().__init__(inference_url=inference_url)

def invoke(self, params: PlaybookGenerationParameters) -> PlaybookGenerationResponse:
    """Generate a playbook when the playbook endpoint is enabled.

    Delegates to the parent class's ``invoke`` when
    ``settings.ENABLE_PLAYBOOK_ENDPOINT`` is truthy.

    Raises:
        FeatureNotAvailable: the playbook endpoint is disabled.
    """
    # Guard first: an unconditional raise ahead of the flag check would make
    # the enabled branch unreachable.
    if not settings.ENABLE_PLAYBOOK_ENDPOINT:
        raise FeatureNotAvailable
    return super().invoke(params)

def self_test(self):
    """Not implemented for this pipeline; always raises ``NotImplementedError``."""
    raise NotImplementedError
Expand All @@ -177,7 +180,10 @@ def __init__(self, inference_url):
super().__init__(inference_url=inference_url)

def invoke(self, params: PlaybookExplanationParameters) -> PlaybookExplanationResponse:
    """Explain a playbook when the playbook endpoint is enabled.

    Delegates to the parent class's ``invoke`` when
    ``settings.ENABLE_PLAYBOOK_ENDPOINT`` is truthy.

    Raises:
        FeatureNotAvailable: the playbook endpoint is disabled.
    """
    # Guard first: an unconditional raise ahead of the flag check would make
    # the enabled branch unreachable.
    if not settings.ENABLE_PLAYBOOK_ENDPOINT:
        raise FeatureNotAvailable
    return super().invoke(params)

def self_test(self):
    """Not implemented for this pipeline; always raises ``NotImplementedError``."""
    raise NotImplementedError
143 changes: 14 additions & 129 deletions ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from abc import ABCMeta
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional

import backoff
from django.apps import apps
Expand All @@ -27,12 +26,12 @@
Suffixes,
WcaSecretManagerError,
)
from ansible_ai_connect.ai.api.exceptions import FeatureNotAvailable
from ansible_ai_connect.ai.api.model_pipelines.exceptions import (
WcaInferenceFailure,
WcaKeyNotFound,
WcaModelIdNotFound,
WcaNoDefaultModelId,
WcaRequestIdCorrelationFailure,
WcaTokenFailure,
)
from ansible_ai_connect.ai.api.model_pipelines.pipelines import (
Expand All @@ -54,12 +53,8 @@
WcaModelRequestException,
WcaTokenRequestException,
ibm_cloud_identity_token_hist,
wca_codegen_playbook_hist,
wca_explain_playbook_hist,
)
from ansible_ai_connect.ai.api.model_pipelines.wca.wca_utils import (
Context,
InferenceResponseChecks,
TokenContext,
TokenResponseChecks,
)
Expand Down Expand Up @@ -296,77 +291,14 @@ class WCASaaSPlaybookGenerationPipeline(WCASaaSPipeline, WCABasePlaybookGenerati
def __init__(self, inference_url):
    """Initialise the pipeline by forwarding ``inference_url`` to the parent initialiser."""
    super().__init__(inference_url=inference_url)

def invoke(self, params: PlaybookGenerationParameters) -> PlaybookGenerationResponse:
    """Generate a playbook when the playbook endpoint is enabled.

    The request/response handling is not repeated here: it is inherited, so
    this override only applies the ``settings.ENABLE_PLAYBOOK_ENDPOINT`` gate
    and then delegates to the parent ``invoke``.

    Raises:
        FeatureNotAvailable: the playbook endpoint is disabled.
    """
    # Check the flag before any work is done so a disabled endpoint never
    # triggers a request.
    if not settings.ENABLE_PLAYBOOK_ENDPOINT:
        raise FeatureNotAvailable
    return super().invoke(params)

def self_test(self):
    """Not implemented for this pipeline; always raises ``NotImplementedError``."""
    raise NotImplementedError


@Register(api_type="wca")
Expand All @@ -375,58 +307,11 @@ class WCASaaSPlaybookExplanationPipeline(WCASaaSPipeline, WCABasePlaybookExplana
def __init__(self, inference_url):
    """Initialise the pipeline by forwarding ``inference_url`` to the parent initialiser."""
    super().__init__(inference_url=inference_url)

def invoke(self, params: PlaybookExplanationParameters) -> PlaybookExplanationResponse:
    """Explain a playbook when the playbook endpoint is enabled.

    The request/response handling is not repeated here: it is inherited, so
    this override only applies the ``settings.ENABLE_PLAYBOOK_ENDPOINT`` gate
    and then delegates to the parent ``invoke``.

    Raises:
        FeatureNotAvailable: the playbook endpoint is disabled.
    """
    # Check the flag before any work is done so a disabled endpoint never
    # triggers a request.
    if not settings.ENABLE_PLAYBOOK_ENDPOINT:
        raise FeatureNotAvailable
    return super().invoke(params)

def self_test(self):
    """Not implemented for this pipeline; always raises ``NotImplementedError``."""
    raise NotImplementedError
Loading

0 comments on commit 33f1e7d

Please sign in to comment.