-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.go
133 lines (112 loc) · 4.04 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package main
import (
"flag"
"fmt"
"log"
"net/http"
"os"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/lambda"
)
// PayloadBuilder translates between plain HTTP traffic and Lambda invocation
// payloads for one gateway flavour (main only wires up the "alb" flavour).
type PayloadBuilder interface {
	// BuildRequest serializes an incoming HTTP request into the event payload
	// that is sent to the Lambda function.
	BuildRequest(req *http.Request) (payload []byte, err error)

	// BuildResponse decodes the function's result payload into an HTTP status
	// code, a response body, and a header map in which each header name may
	// carry several values.
	BuildResponse(payload []byte) (status int, body []byte, headers map[string][]string, err error)
}
// main parses the command-line flags, builds the ALB payload builder and the
// Lambda client, and serves HTTP on the requested address. Every request is
// forwarded to the configured Lambda function, one invocation at a time.
func main() {
	functionName := flag.String("f", "myfunction", "Lambda function name")
	bind := flag.String("l", "", "HTTP listen address (default any)")
	port := flag.Int("p", 8080, "HTTP listen port")
	endpoint := flag.String("e", "", "Lambda API endpoint")
	apiType := flag.String("t", "alb", "HTTP gateway type (\"alb\" for ALB)")
	albMultiValue := flag.Bool("m", false, "Enable multi-value headers. Effective only with -t alb")
	flag.Usage = func() {
		// Write to the same stream flag.PrintDefaults uses (stderr by default)
		// so the usage text is not split between stdout and stderr.
		out := flag.CommandLine.Output()
		fmt.Fprintln(out, "Usage of lambda-local-proxy:")
		flag.PrintDefaults()
		fmt.Fprintln(out, "")
		fmt.Fprintln(out, " Environment variables:")
		fmt.Fprintln(out, " AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN")
	}
	flag.Parse()

	if *apiType != "alb" {
		fmt.Fprintln(os.Stderr, "Unknown gateway type: "+*apiType)
		os.Exit(1)
	}

	// One-slot semaphore shared by all requests: the handler takes the token
	// before invoking the function and returns it afterwards.
	requestFree := make(chan bool, 1)
	requestFree <- true

	pb := NewALBPayloadBuilder(*albMultiValue)
	client := MakeLambdaClient(*endpoint)
	handler := MakeInvokeLambdaHandler(client, *functionName, pb, requestFree)
	http.HandleFunc("/", handler)

	// Bracket IPv6 literals (as net.JoinHostPort does) so "-l ::1" becomes
	// "[::1]:8080" instead of the unparseable "::1:8080".
	host := *bind
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}
	listenAddress := fmt.Sprintf("%s:%d", host, *port)
	log.Fatal(http.ListenAndServe(listenAddress, nil))
}
// MakeLambdaClient creates a Lambda API client from an AWS session that also
// loads the shared config files. A non-empty endpoint overrides the Lambda API
// endpoint; an empty one keeps the SDK default. It panics (via session.Must)
// if the session cannot be created.
func MakeLambdaClient(endpoint string) *lambda.Lambda {
	opts := session.Options{SharedConfigState: session.SharedConfigEnable}
	sess := session.Must(session.NewSessionWithOptions(opts))

	cfg := &aws.Config{}
	if endpoint != "" {
		cfg.Endpoint = aws.String(endpoint)
	}
	return lambda.New(sess, cfg)
}
// MakeInvokeLambdaHandler returns an HTTP handler that forwards each request to
// the Lambda function functionName.
//
// The flow for each request is:
//   - pb converts the request into an invocation payload.
//   - client invokes the function with that payload.
//   - pb converts the function's result back into an HTTP response.
//
// Any failure along the way is reported as a 502 via WriteErrorResponse.
//
// requestFree is used as a one-slot lock. The handler receives a value before
// invoking the function and sends one back when it finishes, so at most one
// invocation is in flight when the channel holds a single token. If the
// channel is closed, the handler returns without writing a response.
func MakeInvokeLambdaHandler(client *lambda.Lambda, functionName string, pb PayloadBuilder, requestFree chan bool) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		// Use the requestFree channel as a lock to prevent more than one inflight
		// request to the lambda function since it has a concurrency of one.
		// Stop waiting if the client disconnects, so abandoned requests do not
		// each trigger an invocation later.
		select {
		case _, ok := <-requestFree:
			if !ok {
				return // Indicates channel closure
			}
		case <-r.Context().Done():
			return
		}
		defer func() { requestFree <- true }()

		// Add proxy headers describing the client connection.
		r.Header.Add("X-Forwarded-For", forwardedForHost(r.RemoteAddr))
		r.Header.Add("X-Forwarded-Proto", "http")
		r.Header.Add("X-Forwarded-Port", forwardedPort(r, "8080"))

		// Convert the HTTP request into an invocation event.
		payload, err := pb.BuildRequest(r)
		if err != nil {
			WriteErrorResponse(w, "Invalid request", err)
			return
		}

		// Invoke Lambda with the event.
		output, err := client.Invoke(&lambda.InvokeInput{
			FunctionName: aws.String(functionName),
			Payload:      payload,
		})
		if err != nil {
			WriteErrorResponse(w, "Failed to invoke Lambda", err)
			return
		}
		// FunctionError is set when the function itself failed. Per the Lambda
		// Invoke API the payload then describes the error, so pass it on rather
		// than dropping it.
		if output.FunctionError != nil {
			WriteErrorResponse(w, "Lambda function error: "+*output.FunctionError+"\n"+string(output.Payload), nil)
			return
		}

		// Convert the function result into status, body and headers.
		status, body, headers, err := pb.BuildResponse(output.Payload)
		if err != nil {
			WriteErrorResponse(w, "Invalid JSON response", err)
			return
		}

		// Headers must be set before WriteHeader; later changes are ignored.
		for key, values := range headers {
			for _, value := range values {
				w.Header().Add(key, value)
			}
		}
		w.WriteHeader(status)
		// A failed write means the client is gone; there is no one to report to.
		w.Write(body)
	}
}

// forwardedForHost strips the port from a "host:port" remote address and the
// brackets around an IPv6 literal, e.g. "[::1]:5000" -> "::1".
// An address without a port (or a bare IPv6 literal) is returned unchanged
// rather than causing an out-of-range slice.
func forwardedForHost(remoteAddr string) string {
	if strings.HasPrefix(remoteAddr, "[") {
		// "[ipv6]:port" or "[ipv6]"
		if end := strings.Index(remoteAddr, "]"); end > 0 {
			return remoteAddr[1:end]
		}
		return remoteAddr
	}
	if strings.Count(remoteAddr, ":") == 1 {
		// "host:port"
		return remoteAddr[:strings.Index(remoteAddr, ":")]
	}
	return remoteAddr
}

// forwardedPort returns the port of the local address the request arrived on,
// read from the net/http server's request context. It returns fallback when
// that address is unavailable, for example for requests not served by an
// http.Server.
func forwardedPort(r *http.Request, fallback string) string {
	// The context value is a net.Addr; only its String method is needed here.
	addr, ok := r.Context().Value(http.LocalAddrContextKey).(fmt.Stringer)
	if !ok {
		return fallback
	}
	s := addr.String()
	i := strings.LastIndex(s, ":")
	if i < 0 || i == len(s)-1 {
		return fallback
	}
	return s[i+1:]
}
// WriteErrorResponse replies with status 502 and a plain-text body.
// The body has the line "502 Bad Gateway", then message on the next line, and
// then err's text on a third line when err is non-nil.
func WriteErrorResponse(w http.ResponseWriter, message string, err error) {
	out := []byte("502 Bad Gateway\n")
	out = append(out, message...)
	if err != nil {
		out = append(out, '\n')
		out = append(out, err.Error()...)
	}
	w.WriteHeader(http.StatusBadGateway)
	w.Write(out)
}