Skip to content

Commit

Permalink
login1: Add RebootWithContext method
Browse files Browse the repository at this point in the history
Existing Reboot() method does not allow using context not inspecting
D-Bus call errors, which makes it difficult to debug and use.

This commit adds new RebootWithContext() method which addresses those
shortcomings.

Closes coreos#387

Signed-off-by: Mateusz Gozdek <mgozdek@microsoft.com>
  • Loading branch information
invidian committed Jan 10, 2022
1 parent 9a42b6b commit 1a186bc
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 0 deletions.
11 changes: 11 additions & 0 deletions login1/dbus.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package login1

import (
"context"
"fmt"
"os"
"strconv"
Expand Down Expand Up @@ -59,6 +60,7 @@ type connectionManager interface {
type Caller interface {
// TODO: This method should eventually be removed, as it provides no context support.
Call(method string, flags dbus.Flags, args ...interface{}) *dbus.Call
CallWithContext(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call
}

// New establishes a connection to the system bus and authenticates.
Expand Down Expand Up @@ -347,6 +349,15 @@ func (c *Conn) Reboot(askForAuth bool) {
c.object.Call(dbusInterface+".Reboot", 0, askForAuth)
}

// Reboot asks logind for a reboot using given context, optionally asking for auth.
func (c *Conn) RebootWithContext(ctx context.Context, askForAuth bool) error {
if call := c.object.CallWithContext(ctx, dbusInterface+".Reboot", 0, askForAuth); call.Err != nil {
return fmt.Errorf("calling reboot: %w", call.Err)
}

return nil
}

// Inhibit takes inhibition lock in logind.
func (c *Conn) Inhibit(what, who, why, mode string) (*os.File, error) {
var fd dbus.UnixFD
Expand Down
231 changes: 231 additions & 0 deletions login1/dbus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package login1_test

import (
"context"
"errors"
"fmt"
"os/user"
"regexp"
Expand Down Expand Up @@ -142,6 +144,168 @@ func Test_Creating_new_connection_with_custom_connection(t *testing.T) {
})
}

//nolint:funlen // Many subtests.
func Test_Rebooting_with_context(t *testing.T) {
t.Parallel()

t.Run("calls_login1_reboot_method_on_manager_interface", func(t *testing.T) {
t.Parallel()

rebootCalled := false

askForReboot := false

connectionWithContextCheck := &mockConnection{
ObjectF: func(string, dbus.ObjectPath) dbus.BusObject {
return &mockObject{
CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
rebootCalled = true

expectedMethodName := "org.freedesktop.login1.Manager.Reboot"

if method != expectedMethodName {
t.Fatalf("Expected method %q being called, got %q", expectedMethodName, method)
}

if len(args) != 1 {
t.Fatalf("Expected one argument to call, got %q", args)
}

askedForReboot, ok := args[0].(bool)
if !ok {
t.Fatalf("Expected first argument to be of type %T, got %T", askForReboot, args[0])
}

if askForReboot != askedForReboot {
t.Fatalf("Expected argument to be %t, got %t", askForReboot, askedForReboot)
}

return &dbus.Call{}
},
}
},
}

testConn, err := login1.NewWithConnection(connectionWithContextCheck)
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}

if err := testConn.RebootWithContext(context.Background(), askForReboot); err != nil {
t.Fatalf("Unexpected error rebooting: %v", err)
}

if !rebootCalled {
t.Fatalf("Expected reboot method call on given D-Bus connection")
}
})

t.Run("asks_for_auth_when_requested", func(t *testing.T) {
t.Parallel()

rebootCalled := false

askForReboot := true

connectionWithContextCheck := &mockConnection{
ObjectF: func(string, dbus.ObjectPath) dbus.BusObject {
return &mockObject{
CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
rebootCalled = true

if len(args) != 1 {
t.Fatalf("Expected one argument to call, got %q", args)
}

askedForReboot, ok := args[0].(bool)
if !ok {
t.Fatalf("Expected first argument to be of type %T, got %T", askForReboot, args[0])
}

if askForReboot != askedForReboot {
t.Fatalf("Expected argument to be %t, got %t", askForReboot, askedForReboot)
}

return &dbus.Call{}
},
}
},
}

testConn, err := login1.NewWithConnection(connectionWithContextCheck)
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}

if err := testConn.RebootWithContext(context.Background(), askForReboot); err != nil {
t.Fatalf("Unexpected error rebooting: %v", err)
}

if !rebootCalled {
t.Fatalf("Expected reboot method call on given D-Bus connection")
}
})

t.Run("use_given_context_for_D-Bus_call", func(t *testing.T) {
t.Parallel()

testKey := struct{}{}
expectedValue := "bar"

ctx := context.WithValue(context.Background(), testKey, expectedValue)

connectionWithContextCheck := &mockConnection{
ObjectF: func(string, dbus.ObjectPath) dbus.BusObject {
return &mockObject{
CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
if val := ctx.Value(testKey); val != expectedValue {
t.Fatalf("Got unexpected context on call")
}

return &dbus.Call{}
},
}
},
}

testConn, err := login1.NewWithConnection(connectionWithContextCheck)
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}

if err := testConn.RebootWithContext(ctx, false); err != nil {
t.Fatalf("Unexpected error rebooting: %v", err)
}
})

t.Run("returns_error_when_D-Bus_call_fails", func(t *testing.T) {
t.Parallel()

expectedError := fmt.Errorf("reboot error")

connectionWithFailingObjectCall := &mockConnection{
ObjectF: func(string, dbus.ObjectPath) dbus.BusObject {
return &mockObject{
CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
return &dbus.Call{
Err: expectedError,
}
},
}
},
}

testConn, err := login1.NewWithConnection(connectionWithFailingObjectCall)
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}

if err := testConn.RebootWithContext(context.Background(), false); !errors.Is(err, expectedError) {
t.Fatalf("Unexpected error rebooting: %v", err)
}
})
}

// mockConnection is a test helper for mocking dbus.Conn.
type mockConnection struct {
ObjectF func(string, dbus.ObjectPath) dbus.BusObject
Expand Down Expand Up @@ -178,3 +342,70 @@ func (m *mockConnection) Close() error {
func (m *mockConnection) BusObject() dbus.BusObject {
return nil
}

// mockObject is a mock of dbus.BusObject.
type mockObject struct {
CallWithContextF func(context.Context, string, dbus.Flags, ...interface{}) *dbus.Call
CallF func(string, dbus.Flags, ...interface{}) *dbus.Call
}

// mockObject must implement dbus.BusObject to be usable for other packages in tests, though not
// all methods must actually be mockable. See https://github.com/dbus/dbus/issues/252 for details.
var _ dbus.BusObject = &mockObject{}

// CallWithContext ...
//
//nolint:lll // Upstream signature, can't do much with that.
func (m *mockObject) CallWithContext(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
if m.CallWithContextF == nil {
return &dbus.Call{}
}

return m.CallWithContextF(ctx, method, flags, args...)
}

// Call ...
func (m *mockObject) Call(method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
if m.CallF == nil {
return &dbus.Call{}
}

return m.CallF(method, flags, args...)
}

// Go ...
func (m *mockObject) Go(method string, flags dbus.Flags, ch chan *dbus.Call, args ...interface{}) *dbus.Call {
return &dbus.Call{}
}

// GoWithContext ...
//
//nolint:lll // Upstream signature, can't do much with that.
func (m *mockObject) GoWithContext(ctx context.Context, method string, flags dbus.Flags, ch chan *dbus.Call, args ...interface{}) *dbus.Call {
return &dbus.Call{}
}

// AddMatchSignal ...
func (m *mockObject) AddMatchSignal(iface, member string, options ...dbus.MatchOption) *dbus.Call {
return &dbus.Call{}
}

// RemoveMatchSignal ...
func (m *mockObject) RemoveMatchSignal(iface, member string, options ...dbus.MatchOption) *dbus.Call {
return &dbus.Call{}
}

// GetProperty ...
func (m *mockObject) GetProperty(p string) (dbus.Variant, error) { return dbus.Variant{}, nil }

// StoreProperty ...
func (m *mockObject) StoreProperty(p string, value interface{}) error { return nil }

// SetProperty ...
func (m *mockObject) SetProperty(p string, v interface{}) error { return nil }

// Destination ...
func (m *mockObject) Destination() string { return "" }

// Path ...
func (m *mockObject) Path() dbus.ObjectPath { return "" }

0 comments on commit 1a186bc

Please sign in to comment.