diff --git a/README.md b/README.md index b19641c..08c577e 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ GLOBAL OPTIONS: --domain value, -n value The tld [$CF_ORIGIN_TLD] --certout value, -c value Certificate output file name (default: "./cert.pem") [$CF_ORIGIN_CERT_OUT_FILE] --keyout value, -k value Private key output file name (default: "./key.pem") [$CF_ORIGIN_KEY_OUT_FILE] + --post-hook value, --ph value Post hook [$CF_ORIGIN_POST_HOOK] --help, -h show help --version, -v print the version ``` diff --git a/cmd/origin-cert-agent/main.go b/cmd/origin-cert-agent/main.go index fbf1107..eabb569 100644 --- a/cmd/origin-cert-agent/main.go +++ b/cmd/origin-cert-agent/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "io/ioutil" "log" "os" "os/signal" @@ -58,6 +57,12 @@ func main() { Value: `./key.pem`, EnvVar: "CF_ORIGIN_KEY_OUT_FILE", }, + cli.StringFlag{ + Name: "post-hook, ph", + Usage: "Post hook", + Value: "", + EnvVar: "CF_ORIGIN_POST_HOOK", + }, } app.Action = Start @@ -66,7 +71,7 @@ func main() { } } -func Start(c *cli.Context) error { +func Start(c *cli.Context) { if len(c.String(`origin-api-key`)) <= 0 { log.Fatal("origin-api-key is a required parameter") } @@ -84,48 +89,37 @@ func Start(c *cli.Context) error { sigchan := make(chan os.Signal, 1) signal.Notify(sigchan, os.Interrupt) - ca, err := agent.NewCertAgent(c.String(`origin-api-key`), c.Duration(`rotation-frequency`), c.Int(`ttl`)) + ca, err := agent.NewCertAgent( + c.String(`origin-api-key`), + c.Duration(`rotation-frequency`), + c.Int(`ttl`), + agent.NewFilesystemCertificateWriter(c.String("certout"), c.String("keyout")), + ) if err != nil { log.Fatal(err) } + + postHook := c.String("post-hook") + if postHook != "" { + ca.Attach(agent.NewPostHookObserver(postHook)) + } + ctx, cancel := context.WithCancel(context.Background()) go ca.Run(ctx, c.String(`domain`)) - var lastID string - for { - select { - case <-time.After(5 * time.Second): - creds, err := ca.GetCertKeyPair(0) - if creds.ID == lastID { - continue - } - lastID = creds.ID - if err != nil { - log.Print(err) - continue - } - err = ioutil.WriteFile(c.String(`certout`), creds.CertPEM, 0600) - if err != nil { - log.Printf("unable to write certificate file: %x", err) - } - err = ioutil.WriteFile(c.String(`keyout`), creds.Key, 0600) - if err != nil { - log.Printf("unable to write key file: %x", err) - } - case <-sigchan: - err = os.Remove(c.String(`certout`)) - if err != nil { - // bugsnag report? - log.Printf("unable to clean up revoked certificate file: %x", err) - } - err = os.Remove(c.String(`keyout`)) - if err != nil { - // bugsnag report? - log.Printf("unable to clean up key file for revoked: %x", err) - } - cancel() - time.Sleep(ShortTick) - os.Exit(127) - } + <-sigchan + + err = os.Remove(c.String(`certout`)) + if err != nil { + // bugsnag report? + log.Printf("unable to clean up revoked certificate file: %x", err) + } + err = os.Remove(c.String(`keyout`)) + if err != nil { + // bugsnag report? + log.Printf("unable to clean up key file for revoked: %x", err) } + cancel() + time.Sleep(ShortTick) + os.Exit(127) } diff --git a/go.mod b/go.mod index a560500..e7f9ce4 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.12 require ( github.com/allingeek/cloudflare-go v0.8.6-0.20190316173248-189614cf3ffc + github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect + github.com/ionrock/procs v0.0.0-20180102005558-f53ef5630f1a github.com/pkg/errors v0.8.1 // indirect github.com/urfave/cli v1.20.0 golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect diff --git a/go.sum b/go.sum index 3ab5922..abd15bc 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,10 @@ github.com/cloudflare/cloudflare-go v0.8.5 h1:k1iz+H2jIL8OnS+bGhNQ6GPldi7VCo2tuW github.com/cloudflare/cloudflare-go v0.8.5/go.mod h1:8KhU6K+zHUEWOSU++mEQYf7D9UZOcQcibUoSm6vCUz4= github.com/codegangsta/cli v1.20.0/go.mod h1:/qJNoX69yVSKu5o4jLyXAENLRyk1uhi7zkbQ3slBdOA= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/ionrock/procs v0.0.0-20180102005558-f53ef5630f1a h1:o1qg0zExjAjMOFL5kHWJFOyqY/jVIFnx1rzMFt9MD+g= +github.com/ionrock/procs v0.0.0-20180102005558-f53ef5630f1a/go.mod h1:ZANLPvV4k0ZsE7hitAAQHZ8vmst7X8wnmuTi4QS1gYw= github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/olekukonko/tablewriter v0.0.1/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/internal/instanceof.go b/internal/instanceof.go new file mode 100644 index 0000000..967c67c --- /dev/null +++ b/internal/instanceof.go @@ -0,0 +1,7 @@ +package internal + +import "reflect" + +func IsInstanceOf(objectPtr, typePtr interface{}) bool { + return reflect.TypeOf(objectPtr) == reflect.TypeOf(typePtr) +} diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 91c0133..2c026ee 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "github.com/gotopple/cf-origin-cert/pkg/observer" "log" "os" "sync" @@ -25,6 +26,8 @@ type CertKeyPair struct { } type CertAgent struct { sync.Mutex + *observer.Observable + writer CertificateWriter apiKey string period time.Duration validity int @@ -33,7 +36,7 @@ type CertAgent struct { cache []CertKeyPair } -func NewCertAgent(apiKey string, period time.Duration, validity int) (*CertAgent, error) { +func NewCertAgent(apiKey string, period time.Duration, validity int, writer CertificateWriter) (*CertAgent, error) { switch validity { case Week: case Month: @@ -49,14 +52,18 @@ func NewCertAgent(apiKey string, period time.Duration, validity int) (*CertAgent } generator := SHA256RSAGenerator{} - return &CertAgent{ - apiKey: apiKey, - period: period, - validity: validity, - api: api, - generator: generator, - cache: []CertKeyPair{}, - }, nil + agent := &CertAgent{ + Observable: observer.MakeObservable(), + writer: writer, + apiKey: apiKey, + period: period, + validity: validity, + api: api, + generator: generator, + cache: []CertKeyPair{}, + } + + return agent, nil } func (a *CertAgent) Run(ctx context.Context, domain string) { @@ -65,7 +72,7 @@ func (a *CertAgent) Run(ctx context.Context, domain string) { subject := fmt.Sprintf("*.%s", domain) dnsNames := []string{domain, subject} - generate := func() { + generateAndWrite := func() *CertKeyPair { key, pem, err := a.generator.GenerateNewPEM(subject, dnsNames) if err != nil { log.Fatal(err) @@ -83,8 +90,13 @@ func (a *CertAgent) Run(ctx context.Context, domain string) { log.Fatal(err) } + result := CertKeyPair{ID: cert.ID, CertPEM: []byte(cert.Certificate), Key: key} + a.writer.Write(&result) + // prepend to cache - a.cache = append([]CertKeyPair{CertKeyPair{ID: cert.ID, CertPEM: []byte(cert.Certificate), Key: key}}, a.cache...) + a.cache = append([]CertKeyPair{result}, a.cache...) + + return &result } cleanup := func(all bool) { @@ -116,12 +128,16 @@ func (a *CertAgent) Run(ctx context.Context, domain string) { } } - generate() + initialCert := generateAndWrite() + a.Notify(initialCert) + for { select { case <-time.After(a.period): - generate() + newCert := generateAndWrite() cleanup(false) + + a.Notify(newCert) case <-ctx.Done(): cleanup(true) return diff --git a/pkg/agent/certificate_writer.go b/pkg/agent/certificate_writer.go new file mode 100644 index 0000000..0a0b19e --- /dev/null +++ b/pkg/agent/certificate_writer.go @@ -0,0 +1,5 @@ +package agent + +type CertificateWriter interface { + Write(certKeyPair *CertKeyPair) +} diff --git a/pkg/agent/filesystem_certificate_writer.go b/pkg/agent/filesystem_certificate_writer.go new file mode 100644 index 0000000..90e8856 --- /dev/null +++ b/pkg/agent/filesystem_certificate_writer.go @@ -0,0 +1,36 @@ +package agent + +import ( + "io/ioutil" + "log" +) + +type FilesystemCertificateWriter struct { + certOutputPath string + keyOutputPath string + lastID string +} + +func NewFilesystemCertificateWriter(certOutputPath string, keyOutputPath string) CertificateWriter { + return &FilesystemCertificateWriter{ + certOutputPath: certOutputPath, + keyOutputPath: keyOutputPath, + } +} + +func (w *FilesystemCertificateWriter) Write(certKeyPair *CertKeyPair) { + if certKeyPair.ID == w.lastID { + return + } + w.lastID = certKeyPair.ID + + err := ioutil.WriteFile(w.certOutputPath, certKeyPair.CertPEM, 0600) + if err != nil { + log.Printf("unable to write certificate file: %x", err) + } + + err = ioutil.WriteFile(w.keyOutputPath, certKeyPair.Key, 0600) + if err != nil { + log.Printf("unable to write key file: %x", err) + } +} diff --git a/pkg/agent/post_hook_observer.go b/pkg/agent/post_hook_observer.go new file mode 100644 index 0000000..d790fa5 --- /dev/null +++ b/pkg/agent/post_hook_observer.go @@ -0,0 +1,53 @@ +package agent + +import ( + "fmt" + "github.com/gotopple/cf-origin-cert/internal" + "github.com/ionrock/procs" + "log" +) + +type PostHookObserver struct { + ch chan interface{} + postHookCommand string +} + +func NewPostHookObserver(command string) *PostHookObserver { + result := &PostHookObserver{ + make(chan interface{}, 2), + command, + } + + go func() { + for { + evt := <-result.ch + if internal.IsInstanceOf(evt, (*CertKeyPair)(nil)) { + result.onNewCertificate() + } + } + }() + + return result +} + +func (a *PostHookObserver) GetChannel() chan interface{} { + return a.ch +} + +func (a *PostHookObserver) onNewCertificate() { + p := procs.NewProcess(a.postHookCommand) + p.ErrHandler = func(line string) string { + fmt.Printf("[POST HOOK - STD ERR] %s\n", line) + return line + } + p.OutputHandler = func(line string) string { + fmt.Printf("[POST HOOK - STD OUT] %s\n", line) + return line + } + err := p.Run() + + if err != nil { + log.Fatal("Post hook failed", err) + } + fmt.Println("Executed post hook") +} diff --git a/pkg/observer/observable.go b/pkg/observer/observable.go new file mode 100644 index 0000000..7a943fa --- /dev/null +++ b/pkg/observer/observable.go @@ -0,0 +1,41 @@ +package observer + +import ( + "sync" +) + +type Observable struct { + observers []chan interface{} + mu *sync.Mutex +} + +func MakeObservable() *Observable { + return &Observable{ + observers: make([]chan interface{}, 0), + mu: &sync.Mutex{}, + } +} + +func (o *Observable) Attach(observer Observer) { + o.mu.Lock() + defer o.mu.Unlock() + o.observers = append(o.observers, observer.GetChannel()) +} + +func (o *Observable) Detach(observer Observer) { + o.mu.Lock() + defer o.mu.Unlock() + for i, v := range o.observers { + c := observer.GetChannel() + if v == c { + o.observers = append(o.observers[:i], o.observers[i+1:]...) + return + } + } +} + +func (o *Observable) Notify(evt interface{}) { + for _, v := range o.observers { + v <- evt + } +} diff --git a/pkg/observer/observer.go b/pkg/observer/observer.go new file mode 100644 index 0000000..afea331 --- /dev/null +++ b/pkg/observer/observer.go @@ -0,0 +1,5 @@ +package observer + +type Observer interface { + GetChannel() chan interface{} +}