diff --git a/cmd/serve.go b/cmd/serve.go index b3d1650..1c97b4f 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -10,6 +10,7 @@ import ( ) var httpAddress, ldapServer, ldapBaseDN, dbLocation, readUser, readPassword, ldapAdminGroupDB string +var ldapIsAd bool var serveCmd = &cobra.Command{ Use: "serve", @@ -35,7 +36,10 @@ var serveCmd = &cobra.Command{ return err } - srv := server.New(db, ldapServer, ldapBaseDN, readUser, readPassword, ldapAdminGroupDB) + srv, err := server.New(db, ldapServer, ldapBaseDN, readUser, readPassword, ldapAdminGroupDB, ldapIsAd) + if err != nil { + return err + } return srv.Listen(httpAddress) }, @@ -54,6 +58,8 @@ func init() { serveCmd.Flags().StringVarP(&ldapAdminGroupDB, "ldap-admin-group-dn", "g", "", "LDAP group DN to use for identifying administrators") + serveCmd.Flags().BoolVar(&ldapIsAd, "ldap-is-ad", false, "Whether the LDAP server is Active Directory") + if err := serveCmd.MarkFlagRequired("ldap-read-user"); err != nil { log.Fatalln(err) } diff --git a/internal/server/auth_middleware.go b/internal/server/auth_middleware.go index ded46e3..4eaa1cb 100644 --- a/internal/server/auth_middleware.go +++ b/internal/server/auth_middleware.go @@ -35,7 +35,7 @@ func basicAuth(auth string) (string, string, error) { return parts[0], parts[1], nil } -func authMiddleware(authHeader string, l ldap.LDAP) (*ldap.User, error) { +func authMiddleware(authHeader string, l *ldap.LDAP) (*ldap.User, error) { sAMAccountName, password, err := basicAuth(authHeader) if err != nil { return nil, err diff --git a/internal/server/service.go b/internal/server/service.go index 754cff8..4ccf5b4 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -11,22 +11,27 @@ type Server struct { app *fiber.App db *bbolt.DB - ldap ldap.LDAP + ldap *ldap.LDAP ldapAdminGroupDN string } -func New(db *bbolt.DB, ldapServer, ldapBaseDN, ldapReadUser, ldapReadPassword, ldapAdminGroupDN string) *Server { +func New(db *bbolt.DB, ldapServer, ldapBaseDN, ldapReadUser, ldapReadPassword, ldapAdminGroupDN string, isAD bool) (*Server, error) { + l, err := ldap.New(ldapServer, ldapBaseDN, ldapReadUser, ldapReadPassword, isAD) + if err != nil { + return nil, err + } + srv := &Server{ fiber.New(), db, - ldap.New(ldapServer, ldapBaseDN, ldapReadUser, ldapReadPassword), + l, ldapAdminGroupDN, } srv.init() - return srv + return srv, nil } func (s *Server) init() {