Skip to content

Commit

Permalink
Added post hook argument
Browse files Browse the repository at this point in the history
  • Loading branch information
juulrecognize committed Jan 21, 2021
1 parent 3ffef9e commit e72c85b
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 52 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ GLOBAL OPTIONS:
--domain value, -n value The tld [$CF_ORIGIN_TLD]
--certout value, -c value Certificate output file name (default: "./cert.pem") [$CF_ORIGIN_CERT_OUT_FILE]
--keyout value, -k value Private key output file name (default: "./key.pem") [$CF_ORIGIN_KEY_OUT_FILE]
--post-hook value, --ph value Post hook [$CF_ORIGIN_POST_HOOK]
--help, -h show help
--version, -v print the version
```
72 changes: 33 additions & 39 deletions cmd/origin-cert-agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"context"
"io/ioutil"
"log"
"os"
"os/signal"
Expand Down Expand Up @@ -58,6 +57,12 @@ func main() {
Value: `./key.pem`,
EnvVar: "CF_ORIGIN_KEY_OUT_FILE",
},
cli.StringFlag{
Name: "post-hook, ph",
Usage: "Post hook",
Value: "",
EnvVar: "CF_ORIGIN_POST_HOOK",
},
}
app.Action = Start

Expand All @@ -66,7 +71,7 @@ func main() {
}
}

func Start(c *cli.Context) error {
func Start(c *cli.Context) {
if len(c.String(`origin-api-key`)) <= 0 {
log.Fatal("origin-api-key is a required parameter")
}
Expand All @@ -84,48 +89,37 @@ func Start(c *cli.Context) error {
sigchan := make(chan os.Signal, 1)
signal.Notify(sigchan, os.Interrupt)

ca, err := agent.NewCertAgent(c.String(`origin-api-key`), c.Duration(`rotation-frequency`), c.Int(`ttl`))
ca, err := agent.NewCertAgent(
c.String(`origin-api-key`),
c.Duration(`rotation-frequency`),
c.Int(`ttl`),
agent.NewFilesystemCertificateWriter(c.String("certout"), c.String("keyout")),
)
if err != nil {
log.Fatal(err)
}

postHook := c.String("post-hook")
if postHook != "" {
ca.Attach(agent.NewPostHookObserver(postHook))
}

ctx, cancel := context.WithCancel(context.Background())
go ca.Run(ctx, c.String(`domain`))
var lastID string
for {
select {
case <-time.After(5 * time.Second):
creds, err := ca.GetCertKeyPair(0)
if creds.ID == lastID {
continue
}
lastID = creds.ID
if err != nil {
log.Print(err)
continue
}
err = ioutil.WriteFile(c.String(`certout`), creds.CertPEM, 0600)
if err != nil {
log.Printf("unable to write certificate file: %x", err)
}
err = ioutil.WriteFile(c.String(`keyout`), creds.Key, 0600)
if err != nil {
log.Printf("unable to write key file: %x", err)
}

case <-sigchan:
err = os.Remove(c.String(`certout`))
if err != nil {
// bugsnag report?
log.Printf("unable to clean up revoked certificate file: %x", err)
}
err = os.Remove(c.String(`keyout`))
if err != nil {
// bugsnag report?
log.Printf("unable to clean up key file for revoked: %x", err)
}
cancel()
time.Sleep(ShortTick)
os.Exit(127)
}
<-sigchan

err = os.Remove(c.String(`certout`))
if err != nil {
// bugsnag report?
log.Printf("unable to clean up revoked certificate file: %x", err)
}
err = os.Remove(c.String(`keyout`))
if err != nil {
// bugsnag report?
log.Printf("unable to clean up key file for revoked: %x", err)
}
cancel()
time.Sleep(ShortTick)
os.Exit(127)
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ go 1.12

require (
github.com/allingeek/cloudflare-go v0.8.6-0.20190316173248-189614cf3ffc
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/ionrock/procs v0.0.0-20180102005558-f53ef5630f1a
github.com/pkg/errors v0.8.1 // indirect
github.com/urfave/cli v1.20.0
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ github.com/cloudflare/cloudflare-go v0.8.5 h1:k1iz+H2jIL8OnS+bGhNQ6GPldi7VCo2tuW
github.com/cloudflare/cloudflare-go v0.8.5/go.mod h1:8KhU6K+zHUEWOSU++mEQYf7D9UZOcQcibUoSm6vCUz4=
github.com/codegangsta/cli v1.20.0/go.mod h1:/qJNoX69yVSKu5o4jLyXAENLRyk1uhi7zkbQ3slBdOA=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/ionrock/procs v0.0.0-20180102005558-f53ef5630f1a h1:o1qg0zExjAjMOFL5kHWJFOyqY/jVIFnx1rzMFt9MD+g=
github.com/ionrock/procs v0.0.0-20180102005558-f53ef5630f1a/go.mod h1:ZANLPvV4k0ZsE7hitAAQHZ8vmst7X8wnmuTi4QS1gYw=
github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/olekukonko/tablewriter v0.0.1/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
Expand Down
7 changes: 7 additions & 0 deletions internal/instanceof.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package internal

import "reflect"

func IsInstanceOf(objectPtr, typePtr interface{}) bool {
return reflect.TypeOf(objectPtr) == reflect.TypeOf(typePtr)
}
42 changes: 29 additions & 13 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"github.com/gotopple/cf-origin-cert/pkg/observer"
"log"
"os"
"sync"
Expand All @@ -25,6 +26,8 @@ type CertKeyPair struct {
}
type CertAgent struct {
sync.Mutex
*observer.Observable
writer CertificateWriter
apiKey string
period time.Duration
validity int
Expand All @@ -33,7 +36,7 @@ type CertAgent struct {
cache []CertKeyPair
}

func NewCertAgent(apiKey string, period time.Duration, validity int) (*CertAgent, error) {
func NewCertAgent(apiKey string, period time.Duration, validity int, writer CertificateWriter) (*CertAgent, error) {
switch validity {
case Week:
case Month:
Expand All @@ -49,14 +52,18 @@ func NewCertAgent(apiKey string, period time.Duration, validity int) (*CertAgent
}

generator := SHA256RSAGenerator{}
return &CertAgent{
apiKey: apiKey,
period: period,
validity: validity,
api: api,
generator: generator,
cache: []CertKeyPair{},
}, nil
agent := &CertAgent{
Observable: observer.MakeObservable(),
writer: writer,
apiKey: apiKey,
period: period,
validity: validity,
api: api,
generator: generator,
cache: []CertKeyPair{},
}

return agent, nil
}

func (a *CertAgent) Run(ctx context.Context, domain string) {
Expand All @@ -65,7 +72,7 @@ func (a *CertAgent) Run(ctx context.Context, domain string) {
subject := fmt.Sprintf("*.%s", domain)
dnsNames := []string{domain, subject}

generate := func() {
generateAndWrite := func() *CertKeyPair {
key, pem, err := a.generator.GenerateNewPEM(subject, dnsNames)
if err != nil {
log.Fatal(err)
Expand All @@ -83,8 +90,13 @@ func (a *CertAgent) Run(ctx context.Context, domain string) {
log.Fatal(err)
}

result := CertKeyPair{ID: cert.ID, CertPEM: []byte(cert.Certificate), Key: key}
a.writer.Write(&result)

// prepend to cache
a.cache = append([]CertKeyPair{CertKeyPair{ID: cert.ID, CertPEM: []byte(cert.Certificate), Key: key}}, a.cache...)
a.cache = append([]CertKeyPair{result}, a.cache...)

return &result
}

cleanup := func(all bool) {
Expand Down Expand Up @@ -116,12 +128,16 @@ func (a *CertAgent) Run(ctx context.Context, domain string) {
}
}

generate()
initialCert := generateAndWrite()
a.Notify(initialCert)

for {
select {
case <-time.After(a.period):
generate()
newCert := generateAndWrite()
cleanup(false)

a.Notify(newCert)
case <-ctx.Done():
cleanup(true)
return
Expand Down
5 changes: 5 additions & 0 deletions pkg/agent/certificate_writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package agent

type CertificateWriter interface {
Write(certKeyPair *CertKeyPair)
}
36 changes: 36 additions & 0 deletions pkg/agent/filesystem_certificate_writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package agent

import (
"io/ioutil"
"log"
)

type FilesystemCertificateWriter struct {
certOutputPath string
keyOutputPath string
lastID string
}

func NewFilesystemCertificateWriter(certOutputPath string, keyOutputPath string) CertificateWriter {
return &FilesystemCertificateWriter{
certOutputPath: certOutputPath,
keyOutputPath: keyOutputPath,
}
}

func (w *FilesystemCertificateWriter) Write(certKeyPair *CertKeyPair) {
if certKeyPair.ID == w.lastID {
return
}
w.lastID = certKeyPair.ID

err := ioutil.WriteFile(w.certOutputPath, certKeyPair.CertPEM, 0600)
if err != nil {
log.Printf("unable to write certificate file: %x", err)
}

err = ioutil.WriteFile(w.keyOutputPath, certKeyPair.Key, 0600)
if err != nil {
log.Printf("unable to write key file: %x", err)
}
}
53 changes: 53 additions & 0 deletions pkg/agent/post_hook_observer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package agent

import (
"fmt"
"github.com/gotopple/cf-origin-cert/internal"
"github.com/ionrock/procs"
"log"
)

type PostHookObserver struct {
ch chan interface{}
postHookCommand string
}

func NewPostHookObserver(command string) *PostHookObserver {
result := &PostHookObserver{
make(chan interface{}, 2),
command,
}

go func() {
for {
evt := <-result.ch
if internal.IsInstanceOf(evt, (*CertKeyPair)(nil)) {
result.onNewCertificate()
}
}
}()

return result
}

func (a *PostHookObserver) GetChannel() chan interface{} {
return a.ch
}

func (a *PostHookObserver) onNewCertificate() {
p := procs.NewProcess(a.postHookCommand)
p.ErrHandler = func(line string) string {
fmt.Printf("[POST HOOK - STD ERR] %s\n", line)
return line
}
p.OutputHandler = func(line string) string {
fmt.Printf("[POST HOOK - STD OUT] %s\n", line)
return line
}
err := p.Run()

if err != nil {
log.Fatal("Post hook failed", err)
}
fmt.Println("Executed post hook")
}
41 changes: 41 additions & 0 deletions pkg/observer/observable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package observer

import (
"sync"
)

type Observable struct {
observers []chan interface{}
mu *sync.Mutex
}

func MakeObservable() *Observable {
return &Observable{
observers: make([]chan interface{}, 0),
mu: &sync.Mutex{},
}
}

func (o *Observable) Attach(observer Observer) {
o.mu.Lock()
defer o.mu.Unlock()
o.observers = append(o.observers, observer.GetChannel())
}

func (o *Observable) Detach(observer Observer) {
o.mu.Lock()
defer o.mu.Unlock()
for i, v := range o.observers {
c := observer.GetChannel()
if v == c {
o.observers = append(o.observers[:i], o.observers[i+1:]...)
return
}
}
}

func (o *Observable) Notify(evt interface{}) {
for _, v := range o.observers {
v <- evt
}
}
5 changes: 5 additions & 0 deletions pkg/observer/observer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package observer

type Observer interface {
GetChannel() chan interface{}
}

0 comments on commit e72c85b

Please sign in to comment.