Skip to content

Commit

Permalink
Update dall-e openai support (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored May 17, 2024
1 parent 4bdb995 commit cbc913b
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 68 deletions.
65 changes: 31 additions & 34 deletions examples/llm/openai/thread/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,72 +3,69 @@ package main
import (
"context"
"fmt"
"strings"

"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/thread"
"github.com/henomis/lingoose/transformer"
)

type Answer struct {
Answer string `json:"answer" jsonschema:"description=the pirate answer"`
type Image struct {
Description string `json:"description" jsonschema:"description=the description of the image that should be created"`
}

func getAnswer(a Answer) string {
return "🦜 ☠️ " + a.Answer
func crateImage(i Image) string {
d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize512x512)
imageURL, err := d.Transform(context.Background(), i.Description)
if err != nil {
return fmt.Errorf("error creating image: %w", err).Error()
}

fmt.Println("Image created with url:", imageURL)

return imageURL.(string)
}

func newStr(str string) *string {
return &str
}

func main() {
openaillm := openai.New()
openaillm.WithToolChoice(newStr("getPirateAnswer"))
openaillm := openai.New().WithModel(openai.GPT4o)
openaillm.WithToolChoice(newStr("auto"))
err := openaillm.BindFunction(
getAnswer,
"getPirateAnswer",
"use this function to get the pirate answer",
crateImage,
"createImage",
"use this function to create an image from a description",
)
if err != nil {
panic(err)
}

t := thread.New().AddMessage(
thread.NewUserMessage().AddContent(
thread.NewTextContent("Hello, I'm a user"),
).AddContent(
thread.NewTextContent("Can you greet me?"),
),
).AddMessage(
thread.NewUserMessage().AddContent(
thread.NewTextContent("please greet me as a pirate."),
thread.NewTextContent("Please, create an image that inspires you"),
),
)

fmt.Println(t)

err = openaillm.Generate(context.Background(), t)
if err != nil {
panic(err)
}

t.AddMessage(thread.NewUserMessage().AddContent(
thread.NewTextContent("now translate to italian as a poem"),
))
if t.LastMessage().Role == thread.RoleTool {
t.AddMessage(thread.NewUserMessage().AddContent(
thread.NewImageContentFromURL(
strings.ReplaceAll(t.LastMessage().Contents[0].AsToolResponseData().Result, `"`, ""),
),
).AddContent(
thread.NewTextContent("can you describe the image?"),
))

fmt.Println(t)
// disable functions
openaillm.WithToolChoice(nil)
openaillm.WithStream(true, func(a string) {
if a == openai.EOS {
fmt.Printf("\n")
return
err = openaillm.Generate(context.Background(), t)
if err != nil {
panic(err)
}
fmt.Printf("%s", a)
})

err = openaillm.Generate(context.Background(), t)
if err != nil {
panic(err)
}

fmt.Println(t)
Expand Down
7 changes: 5 additions & 2 deletions examples/transformer/dalle/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ package main

import (
"context"
"fmt"

"github.com/henomis/lingoose/transformer"
)

func main() {

d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize1024).AsFile("test.png")
d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize1024x1024)

_, err := d.Transform(context.Background(), "a goose working with pipelines")
imageURL, err := d.Transform(context.Background(), "a goose working with pipelines")
if err != nil {
panic(err)
}

fmt.Println("Image created:", imageURL)
}
13 changes: 9 additions & 4 deletions llm/openai/function.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package openai

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"

"github.com/invopop/jsonschema"
"github.com/sashabaranov/go-openai"
Expand Down Expand Up @@ -187,11 +189,14 @@ func callFnWithArgumentAsJSON(fn interface{}, argumentAsJSON string) (string, er

// Marshal the function result to JSON
if len(result) > 0 {
jsonResultData, errMarshal := json.Marshal(result[0].Interface())
if errMarshal != nil {
return "", fmt.Errorf("error marshaling result: %w", errMarshal)
var resultBytes bytes.Buffer
enc := json.NewEncoder(&resultBytes)
enc.SetEscapeHTML(false)
err = enc.Encode(result[0].Interface())
if err != nil {
return "", fmt.Errorf("error marshaling result: %w", err)
}
return string(jsonResultData), nil
return strings.TrimSpace(resultBytes.String()), nil
}

return "", nil
Expand Down
94 changes: 66 additions & 28 deletions transformer/dall-e.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ type DallEImageOutput any
type DallEImageSize string

const (
DallEImageSize256 DallEImageSize = openai.CreateImageSize256x256
DallEImageSize512 DallEImageSize = openai.CreateImageSize512x512
DallEImageSize1024 DallEImageSize = openai.CreateImageSize1024x1024
DallEImageSize256x256 DallEImageSize = openai.CreateImageSize256x256
DallEImageSize512x512 DallEImageSize = openai.CreateImageSize512x512
DallEImageSize1024x1024 DallEImageSize = openai.CreateImageSize1024x1024
DallEImageSize1792x104 DallEImageSize = openai.CreateImageSize1792x1024
DallEImageSize1024x1792 DallEImageSize = openai.CreateImageSize1024x1792
)

type DallEImageFormat string
Expand All @@ -30,19 +32,45 @@ const (
DallEImageFormatImage DallEImageFormat = "image"
)

type DallEModel string

const (
DallEModel2 DallEModel = openai.CreateImageModelDallE2
DallEModel3 DallEModel = openai.CreateImageModelDallE3
)

type DallEImageQuality string

const (
DallEImageQualityHD DallEImageQuality = openai.CreateImageQualityHD
DallEImageQualityStandard DallEImageQuality = openai.CreateImageQualityStandard
)

type DallEImageStyle string

const (
DallEImageStyleVivid DallEImageStyle = openai.CreateImageStyleVivid
DallEImageStyleNatural DallEImageStyle = openai.CreateImageStyleNatural
)

type DallE struct {
openAIClient *openai.Client
model DallEModel
imageSize DallEImageSize
imageFormat DallEImageFormat
imageFile string
imageStyle DallEImageStyle
imageQuality DallEImageQuality
}

func NewDallE() *DallE {
openAIKey := os.Getenv("OPENAI_API_KEY")
return &DallE{
openAIClient: openai.NewClient(openAIKey),
imageSize: DallEImageSize256,
model: DallEModel2,
imageSize: DallEImageSize256x256,
imageFormat: DallEImageFormatURL,
imageStyle: DallEImageStyleNatural,
imageQuality: DallEImageQualityStandard,
}
}

Expand All @@ -56,73 +84,83 @@ func (d *DallE) WithImageSize(imageSize DallEImageSize) *DallE {
return d
}

func (d *DallE) AsURL() *DallE {
d.imageFormat = DallEImageFormatURL
func (d *DallE) WithImageStyle(imageStyle DallEImageStyle) *DallE {
d.imageStyle = imageStyle
return d
}

func (d *DallE) AsFile(path string) *DallE {
d.imageFormat = DallEImageFormatFile
d.imageFile = path
func (d *DallE) WithImageQuality(imageQuality DallEImageQuality) *DallE {
d.imageQuality = imageQuality
return d
}

func (d *DallE) AsImage() *DallE {
d.imageFormat = DallEImageFormatImage
func (d *DallE) WithModel(model DallEModel) *DallE {
d.model = model
return d
}

func (d *DallE) WithImageFormat(imageFormat DallEImageFormat) *DallE {
d.imageFormat = imageFormat
return d
}

func (d *DallE) Transform(ctx context.Context, input string) (any, error) {
switch d.imageFormat {
case DallEImageFormatURL:
return d.transformToURL(ctx, input)
return d.TransformAsURL(ctx, input)
case DallEImageFormatFile:
return d.transformToFile(ctx, input)
return d.TransformAsFile(ctx, input, nil)
case DallEImageFormatImage:
return d.transformToImage(ctx, input)
return d.TransformToImage(ctx, input)
default:
return "", fmt.Errorf("unknown image format: %s", d.imageFormat)
}
}

func (d *DallE) transformToURL(ctx context.Context, input string) (any, error) {
func (d *DallE) TransformAsURL(ctx context.Context, input string) (string, error) {
reqURL := openai.ImageRequest{
Prompt: input,
Model: string(d.model),
Size: string(d.imageSize),
Quality: string(d.imageQuality),
Style: string(d.imageStyle),
ResponseFormat: openai.CreateImageResponseFormatURL,
N: 1,
}

respURL, err := d.openAIClient.CreateImage(ctx, reqURL)
if err != nil {
return nil, err
return "", err
}

return respURL.Data[0].URL, nil
}

func (d *DallE) transformToFile(ctx context.Context, input string) (any, error) {
imgData, err := d.transformToImage(ctx, input)
func (d *DallE) TransformAsFile(ctx context.Context, input string, file *os.File) (string, error) {
imgData, err := d.TransformToImage(ctx, input)
if err != nil {
return nil, err
return "", err
}

file, err := os.Create(d.imageFile)
if err != nil {
return nil, err
if file == nil {
// create a temporary file
file, err = os.CreateTemp("", "dall-e-*.png")
if err != nil {
return "", err
}
}

defer file.Close()

err = png.Encode(file, imgData.(image.Image))
err = png.Encode(file, imgData)
if err != nil {
return nil, err
return "", err
}

var output interface{}
return output, nil
return file.Name(), nil
}

func (d *DallE) transformToImage(ctx context.Context, input string) (any, error) {
func (d *DallE) TransformToImage(ctx context.Context, input string) (image.Image, error) {
reqBase64 := openai.ImageRequest{
Prompt: input,
Size: string(d.imageSize),
Expand Down

0 comments on commit cbc913b

Please sign in to comment.