-
Notifications
You must be signed in to change notification settings - Fork 8.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add fal tool for image generation
- Loading branch information
Showing
13 changed files
with
738 additions
and
23 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import requests | ||
|
||
from core.tools.errors import ToolProviderCredentialValidationError | ||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController | ||
|
||
|
||
class FalProvider(BuiltinToolProviderController): | ||
def _validate_credentials(self, credentials: dict) -> None: | ||
url = "https://queue.fal.run/fal-ai/flux/schnell" | ||
headers = { | ||
"Content-Type": "application/json", | ||
"Authorization": f"Key {credentials.get('fal_api_key')}", | ||
} | ||
data = {"prompt": "cute girl, blue eyes, white hair, anime style."} | ||
response = requests.post(url, headers=headers, data=data) | ||
if response.status_code != 200: | ||
raise ToolProviderCredentialValidationError("Fal API key is invalid") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
identity: | ||
author: zhuhao | ||
name: fal | ||
label: | ||
en_US: Fal | ||
zh_CN: Fal | ||
description: | ||
en_US: The image generation API provided by fal.ai includes Flux, AuraFlow, Stable Diffusion and Kolors models. | ||
zh_CN: Fal提供的图片生成API, 包含Flux,AuraFlow,Stable Diffusion和Kolors模型。 | ||
icon: icon.svg | ||
tags: | ||
- image | ||
credentials_for_provider: | ||
fal_api_key: | ||
type: secret-input | ||
required: true | ||
label: | ||
en_US: Fal API Key | ||
placeholder: | ||
en_US: Please input your Fal API key | ||
url: https://fal.ai/dashboard/keys |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Any, Union | ||
|
||
import fal_client | ||
|
||
from core.tools.entities.tool_entities import ToolInvokeMessage | ||
from core.tools.tool.builtin_tool import BuiltinTool | ||
|
||
|
||
class AuraFlowTool(BuiltinTool): | ||
""" | ||
A tool for generating image via Fal.ai AuraFlow model | ||
""" | ||
|
||
def _invoke( | ||
self, user_id: str, tool_parameters: dict[str, Any] | ||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | ||
api_key = self.runtime.credentials.get("fal_api_key", "") | ||
if not api_key: | ||
return self.create_text_message("Please input fal api key") | ||
client = fal_client.SyncClient(key=api_key) | ||
payload = { | ||
"prompt": tool_parameters.get("prompt"), | ||
"num_images": tool_parameters.get("num_images", 1), | ||
"seed": tool_parameters.get("seed"), | ||
"guidance_scale": tool_parameters.get("guidance_scale", 3.5), | ||
"num_inference_steps": tool_parameters.get("num_inference_steps", 50), | ||
"expand_prompt": tool_parameters.get("expand_prompt", True), | ||
} | ||
handler = client.submit("fal-ai/aura-flow", arguments=payload) | ||
request_id = handler.request_id | ||
res = client.result("fal-ai/aura-flow", request_id) | ||
result = [] | ||
for image in res.get("images", []): | ||
result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
identity: | ||
name: auraflow | ||
author: zhuhao | ||
label: | ||
en_US: AuraFlow | ||
icon: icon.svg | ||
description: | ||
human: | ||
en_US: Generate image via Fal.ai AuraFlow model. | ||
llm: This tool is used to generate image from prompt via Fal.ai AuraFlow model. | ||
parameters: | ||
- name: prompt | ||
type: string | ||
required: true | ||
label: | ||
en_US: prompt | ||
zh_Hans: 提示词 | ||
human_description: | ||
en_US: The text prompt used to generate the image. | ||
zh_Hans: 建议用英文的生成图片提示词以获得更好的生成效果。 | ||
llm_description: this prompt text will be used to generate image. | ||
form: llm | ||
- name: num_inference_steps | ||
type: number | ||
required: true | ||
default: 28 | ||
min: 1 | ||
max: 100 | ||
label: | ||
en_US: Num Inference Steps | ||
zh_Hans: 生成图片的步数 | ||
form: form | ||
human_description: | ||
en_US: The number of inference steps to perform. More steps produce higher quality but take longer. | ||
zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 | ||
- name: seed | ||
type: number | ||
min: 0 | ||
max: 9999999999 | ||
label: | ||
en_US: Seed | ||
zh_Hans: 种子 | ||
human_description: | ||
en_US: The same seed and prompt can produce similar images. | ||
zh_Hans: 相同的种子和提示可以产生相似的图像。 | ||
form: form | ||
- name: guidance_scale | ||
type: number | ||
required: false | ||
default: 3.5 | ||
label: | ||
en_US: Guidance Scale | ||
zh_Hans: 引导缩放比例 | ||
form: form | ||
human_description: | ||
en_US: The Guidance scale is a measure of how close you want the model to stick to your prompt when looking for a related image to you. | ||
zh_Hans: 无分类器引导比例是衡量在寻找相关图像展示时最贴近提示的一个尺度. | ||
- name: num_images | ||
type: number | ||
default: 1 | ||
label: | ||
en_US: Image Number | ||
zh_Hans: 图像数量 | ||
human_description: | ||
en_US: The number of images to generate | ||
zh_Hans: 生成图像的数量 | ||
form: form | ||
- name: expand_prompt | ||
type: boolean | ||
default: true | ||
label: | ||
en_US: prompt expansion | ||
zh_Hans: 是否进行提示词扩展 | ||
human_description: | ||
en_US: Whether to perform prompt expansion | ||
zh_Hans: 是否进行提示词扩展 | ||
llm_description: Perform prompt expansion if true. | ||
form: llm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from typing import Any, Union | ||
|
||
import fal_client | ||
|
||
from core.tools.entities.tool_entities import ToolInvokeMessage | ||
from core.tools.tool.builtin_tool import BuiltinTool | ||
|
||
FLUX_MODEL = {"schnell": "fal-ai/flux/schnell", "dev": "fal-ai/flux/dev", "pro": ""} | ||
|
||
|
||
class FluxTool(BuiltinTool): | ||
""" | ||
A tool for generating image via Fal.ai Flux model | ||
""" | ||
|
||
def _invoke( | ||
self, user_id: str, tool_parameters: dict[str, Any] | ||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | ||
api_key = self.runtime.credentials.get("fal_api_key", "") | ||
if not api_key: | ||
return self.create_text_message("Please input fal api key") | ||
client = fal_client.SyncClient(key=api_key) | ||
model = tool_parameters.get("model", "schnell") | ||
model_name = FLUX_MODEL.get(model) | ||
payload = { | ||
"prompt": tool_parameters.get("prompt"), | ||
"image_size": tool_parameters.get("image_size", "landscape_4_3"), | ||
"num_images": tool_parameters.get("num_images", 1), | ||
"seed": tool_parameters.get("seed"), | ||
"num_inference_steps": tool_parameters.get("num_inference_steps", 4), | ||
"sync_mode": tool_parameters.get("sync_mode", True), | ||
"enable_safety_checker": tool_parameters.get("enable_safety_checker", True), | ||
} | ||
handler = client.submit(model_name, arguments=payload) | ||
request_id = handler.request_id | ||
res = client.result(model_name, request_id) | ||
result = [] | ||
for image in res.get("images", []): | ||
result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
identity: | ||
name: flux | ||
author: zhuhao | ||
label: | ||
en_US: Flux | ||
icon: icon.svg | ||
description: | ||
human: | ||
en_US: Generate image via Fal.ai flux model. | ||
llm: This tool is used to generate image from prompt via Fal.ai flux model. | ||
parameters: | ||
- name: model | ||
type: select | ||
required: true | ||
options: | ||
- value: schnell | ||
label: | ||
en_US: Flux.1-schnell | ||
- value: dev | ||
label: | ||
en_US: Flux.1-dev | ||
- value: pro | ||
label: | ||
en_US: Flux.1-pro | ||
default: schnell | ||
label: | ||
en_US: Choose Image Model | ||
zh_Hans: 选择生成图片的模型 | ||
form: form | ||
- name: prompt | ||
type: string | ||
required: true | ||
label: | ||
en_US: prompt | ||
zh_Hans: 提示词 | ||
human_description: | ||
en_US: The text prompt used to generate the image. | ||
zh_Hans: 建议用英文的生成图片提示词以获得更好的生成效果。 | ||
llm_description: this prompt text will be used to generate image. | ||
form: llm | ||
- name: image_size | ||
type: select | ||
required: true | ||
options: | ||
- value: square_hd | ||
label: | ||
en_US: square_hd | ||
- value: square | ||
label: | ||
en_US: square | ||
- value: portrait_4_3 | ||
label: | ||
en_US: portrait_4_3 | ||
- value: portrait_16_9 | ||
label: | ||
en_US: portrait_16_9 | ||
- value: landscape_4_3 | ||
label: | ||
en_US: landscape_4_3 | ||
- value: landscape_16_9 | ||
label: | ||
en_US: landscape_16_9 | ||
default: landscape_4_3 | ||
label: | ||
en_US: The size of the generated image | ||
zh_Hans: 选择生成的图片大小 | ||
form: form | ||
- name: num_inference_steps | ||
type: number | ||
required: true | ||
default: 28 | ||
min: 1 | ||
max: 100 | ||
label: | ||
en_US: Num Inference Steps | ||
zh_Hans: 生成图片的步数 | ||
form: form | ||
human_description: | ||
en_US: The number of inference steps to perform. More steps produce higher quality but take longer. | ||
zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 | ||
- name: seed | ||
type: number | ||
min: 0 | ||
max: 9999999999 | ||
label: | ||
en_US: Seed | ||
zh_Hans: 种子 | ||
human_description: | ||
en_US: The same seed and prompt can produce similar images. | ||
zh_Hans: 相同的种子和提示可以产生相似的图像。 | ||
form: form | ||
- name: guidance_scale | ||
type: number | ||
required: false | ||
default: 3.5 | ||
label: | ||
en_US: Guidance Scale | ||
zh_Hans: 引导缩放比例 | ||
form: form | ||
human_description: | ||
en_US: The Guidance scale is a measure of how close you want the model to stick to your prompt when looking for a related image to you. | ||
zh_Hans: 无分类器引导比例是衡量在寻找相关图像展示时最贴近提示的一个尺度. | ||
- name: num_images | ||
type: number | ||
default: 1 | ||
label: | ||
en_US: Image Number | ||
zh_Hans: 图像数量 | ||
human_description: | ||
en_US: The number of images to generate | ||
zh_Hans: 生成图像的数量 | ||
form: form | ||
- name: sync_mode | ||
type: boolean | ||
default: false | ||
label: | ||
en_US: Sync Mode | ||
zh_Hans: 同步模式 | ||
human_description: | ||
en_US: The function will wait for the image to be generated and uploaded before returning the response if true | ||
zh_Hans: 为真时该函数将在返回响应之前等待图像生成并上传. | ||
form: form | ||
- name: enable_safety_checker | ||
type: boolean | ||
default: true | ||
label: | ||
en_US: Enable Safety Checker | ||
zh_Hans: 开启安全检测 | ||
human_description: | ||
en_US: The safety checker will be enabled if true | ||
zh_Hans: 为真时开启安全检测. | ||
form: form |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import Any, Union | ||
|
||
import fal_client | ||
|
||
from core.tools.entities.tool_entities import ToolInvokeMessage | ||
from core.tools.tool.builtin_tool import BuiltinTool | ||
|
||
|
||
class AuraFlowTool(BuiltinTool): | ||
""" | ||
A tool for generating image via Fal.ai AuraFlow model | ||
""" | ||
|
||
def _invoke( | ||
self, user_id: str, tool_parameters: dict[str, Any] | ||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | ||
api_key = self.runtime.credentials.get("fal_api_key", "") | ||
if not api_key: | ||
return self.create_text_message("Please input fal api key") | ||
client = fal_client.SyncClient(key=api_key) | ||
payload = { | ||
"prompt": tool_parameters.get("prompt"), | ||
"negative_prompt": tool_parameters.get("negative_prompt", ""), | ||
"image_size": tool_parameters.get("image_size", "square_hd"), | ||
"num_images": tool_parameters.get("num_images", 1), | ||
"seed": tool_parameters.get("seed"), | ||
"guidance_scale": tool_parameters.get("guidance_scale", 5), | ||
"num_inference_steps": tool_parameters.get("num_inference_steps", 50), | ||
"sync_mode": tool_parameters.get("sync_mode", True), | ||
"enable_safety_checker": tool_parameters.get("enable_safety_checker", True), | ||
} | ||
handler = client.submit("fal-ai/kolors", arguments=payload) | ||
request_id = handler.request_id | ||
res = client.result("fal-ai/kolors", request_id) | ||
result = [] | ||
for image in res.get("images", []): | ||
result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) | ||
return result |
Oops, something went wrong.