diff --git a/net/multi_listen.go b/net/multi_listen.go new file mode 100644 index 00000000..d5088052 --- /dev/null +++ b/net/multi_listen.go @@ -0,0 +1,182 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "context" + "fmt" + "net" + "sync" +) + +// connErrPair pairs conn and error which is returned by accept on sub-listeners. +type connErrPair struct { + conn net.Conn + err error +} + +// multiListener implements net.Listener +type multiListener struct { + listeners []net.Listener + + wg sync.WaitGroup + mu sync.Mutex + closed bool + + // connErrQueue holds the connections accepted by sub-listeners + connErrQueue []connErrPair + + // acceptReadyCh is used as a semaphore to wake up the waiting + // multiListener.Accept() when new connections are available + acceptReadyCh chan any +} + +// compile time check to ensure *multiListener implements net.Listener +var _ net.Listener = &multiListener{} + +// MultiListen returns net.Listener which can listen on and accept connections for +// the given network on multiple addresses. Internally it uses stdlib to create +// sub-listener and multiplexes connection requests using go-routines. +// The network must be "tcp", "tcp4" or "tcp6". +// It follows the semantics of net.Listen that primarily means: +// 1. If the host is an unspecified/zero IP address with "tcp" network, MultiListen +// listens on all available unicast and anycast IP addresses of the local system. +// 2. Use "tcp4" or "tcp6" to exclusively listen on IPv4 or IPv6 family, respectively. +// 3. The host can accept names (e.g, localhost) and it will create a listener for at +// most one of the host's IP. +func MultiListen(ctx context.Context, network string, addrs []string) (net.Listener, error) { + return multiListen( + ctx, + network, + addrs, + func(ctx context.Context, network, address string) (net.Listener, error) { + var lc net.ListenConfig + return lc.Listen(ctx, network, address) + }) +} + +// multiListen implements MultiListen by consuming stdlib functions as dependency allowing +// mocking for unit-testing. +func multiListen( + ctx context.Context, + network string, + addrs []string, + listenFunc func(ctx context.Context, network, address string) (net.Listener, error), +) (net.Listener, error) { + if !(network == "tcp" || network == "tcp4" || network == "tcp6") { + return nil, fmt.Errorf("network '%s' not supported", network) + } + if len(addrs) == 0 { + return nil, fmt.Errorf("no address provided to listen on") + } + + ml := &multiListener{ + acceptReadyCh: make(chan any), + } + + for _, addr := range addrs { + l, err := listenFunc(ctx, network, addr) + if err != nil { + // close all the sub-listeners and exit + _ = ml.Close() + return nil, err + } + ml.listeners = append(ml.listeners, l) + } + + for _, l := range ml.listeners { + ml.wg.Add(1) + go func(l net.Listener) { + defer ml.wg.Done() + for { + conn, err := l.Accept() + ml.mu.Lock() + if ml.closed { + ml.mu.Unlock() + return + } + // enqueue the accepted connection + ml.connErrQueue = append(ml.connErrQueue, connErrPair{conn: conn, err: err}) + + // signal the waiting ml.Accept() to consume accepted connection from the queue. + select { + case ml.acceptReadyCh <- struct{}{}: + default: + } + ml.mu.Unlock() + } + }(l) + } + return ml, nil +} + +// Accept implements net.Listener. +// It waits for and returns a connection from any of the sub-listener. +func (ml *multiListener) Accept() (net.Conn, error) { + for { + // atomically return and remove the first element of the queue if it's not empty + ml.mu.Lock() + if len(ml.connErrQueue) > 0 { + connErr := ml.connErrQueue[0] + ml.connErrQueue = ml.connErrQueue[1:] + ml.mu.Unlock() + return connErr.conn, connErr.err + } + ml.mu.Unlock() + + // wait for any sub-listener to enqueue an accepted connection + _, ok := <-ml.acceptReadyCh + if !ok { + // The "acceptReadyCh" channel will be closed only when Close() is called on the multiListener. + // Closing of this channel implies that all sub-listeners are also closed, which causes a + // "use of closed network connection" error on their Accept() calls. We return the same error + // for multiListener.Accept() if multiListener.Close() has already been called. + return nil, fmt.Errorf("use of closed network connection") + } + } +} + +// Close implements net.Listener. +// It will close all sub-listeners and wait for the go-routines to exit. +func (ml *multiListener) Close() error { + ml.mu.Lock() + if ml.closed { + ml.mu.Unlock() + return fmt.Errorf("use of closed network connection") + } + ml.closed = true + close(ml.acceptReadyCh) + ml.mu.Unlock() + + // Closing the listeners causes Accept() to immediately return an error, + // which serves as the exit condition for the sub-listener go-routines. + for _, l := range ml.listeners { + _ = l.Close() + } + + // Wait for all the sub-listener go-routines to exit. + ml.wg.Wait() + return nil +} + +// Addr is an implementation of the net.Listener interface. +// It always returns the address of the first listener. +// Callers should use conn.LocalAddr() to obtain the actual +// local address of the sub-listener. +func (ml *multiListener) Addr() net.Addr { + return ml.listeners[0].Addr() +} diff --git a/net/multi_listen_test.go b/net/multi_listen_test.go new file mode 100644 index 00000000..885e7016 --- /dev/null +++ b/net/multi_listen_test.go @@ -0,0 +1,440 @@ +package net + +import ( + "context" + "fmt" + "net" + "reflect" + "strconv" + "sync/atomic" + "testing" + "time" +) + +type fakeCon struct { + remoteAddr net.Addr +} + +func (f *fakeCon) Read(_ []byte) (n int, err error) { + return 0, nil +} + +func (f *fakeCon) Write(_ []byte) (n int, err error) { + return 0, nil +} + +func (f *fakeCon) Close() error { + return nil +} + +func (f *fakeCon) LocalAddr() net.Addr { + return nil +} + +func (f *fakeCon) RemoteAddr() net.Addr { + return f.remoteAddr +} + +func (f *fakeCon) SetDeadline(_ time.Time) error { + return nil +} + +func (f *fakeCon) SetReadDeadline(_ time.Time) error { + return nil +} + +func (f *fakeCon) SetWriteDeadline(_ time.Time) error { + return nil +} + +var _ net.Conn = &fakeCon{} + +type fakeListener struct { + addr net.Addr + index int + err error + closed atomic.Bool + connErrPairs []connErrPair +} + +func (f *fakeListener) Accept() (net.Conn, error) { + if f.index < len(f.connErrPairs) { + index := f.index + connErr := f.connErrPairs[index] + f.index++ + return connErr.conn, connErr.err + } + for { + if f.closed.Load() { + return nil, fmt.Errorf("use of closed network connection") + } + } +} + +func (f *fakeListener) Close() error { + f.closed.Store(true) + return nil +} + +func (f *fakeListener) Addr() net.Addr { + return f.addr +} + +var _ net.Listener = &fakeListener{} + +func listenFuncFactory(listeners []*fakeListener) func(_ context.Context, network string, address string) (net.Listener, error) { + index := 0 + return func(_ context.Context, network string, address string) (net.Listener, error) { + if index < len(listeners) { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, err + } + listener := listeners[index] + addr := &net.TCPAddr{ + IP: ParseIPSloppy(host), + Port: port, + } + if err != nil { + return nil, err + } + listener.addr = addr + index++ + + if listener.err != nil { + return nil, listener.err + } + return listener, nil + } + return nil, nil + } +} + +func TestMultiListen(t *testing.T) { + testCases := []struct { + name string + network string + addrs []string + fakeListeners []*fakeListener + errString string + }{ + { + name: "unsupported network", + network: "udp", + errString: "network 'udp' not supported", + }, + { + name: "no host", + network: "tcp", + errString: "no address provided to listen on", + }, + { + name: "valid", + network: "tcp", + addrs: []string{"127.0.0.1:12345"}, + fakeListeners: []*fakeListener{{connErrPairs: []connErrPair{}}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, tc.network, tc.addrs, listenFuncFactory(tc.fakeListeners)) + + if tc.errString != "" { + assertError(t, tc.errString, err) + } else { + assertNoError(t, err) + } + if ml != nil { + err = ml.Close() + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + } + }) + } +} + +func TestMultiListen_Addr(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, "tcp", []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"}, listenFuncFactory( + []*fakeListener{{}, {}, {}}, + )) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + + if ml.Addr().String() != "10.10.10.10:5000" { + t.Errorf("Expected '10.10.10.10:5000' but got '%s'", ml.Addr().String()) + } + + err = ml.Close() + if err != nil { + t.Errorf("Did not expect error: %v", err) + } +} + +func TestMultiListen_Close(t *testing.T) { + testCases := []struct { + name string + addrs []string + runner func(listener net.Listener, acceptCalls int) error + fakeListeners []*fakeListener + acceptCalls int + errString string + }{ + { + name: "close", + addrs: []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{}, {}, {}}, + }, + { + name: "close with pending connections", + addrs: []string{"10.10.10.10:5001", "192.168.1.10:5002", "127.0.0.1:5003"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("10.10.10.10"), Port: 50001}}, + }}}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("192.168.1.10"), Port: 50002}}, + }, + }}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 50003}}, + }}, + }}, + }, + { + name: "close with no pending connections", + addrs: []string{"10.10.10.10:3001", "192.168.1.10:3002", "127.0.0.1:3003"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("10.10.10.10"), Port: 50001}}, + }}}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("192.168.1.10"), Port: 50002}}, + }, + }}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 50003}}, + }}, + }}, + acceptCalls: 3, + }, + { + name: "close on close", + addrs: []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + + err = ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{}, {}, {}}, + errString: "use of closed network connection", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, "tcp", tc.addrs, listenFuncFactory(tc.fakeListeners)) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + err = tc.runner(ml, tc.acceptCalls) + if tc.errString != "" { + assertError(t, tc.errString, err) + } else { + assertNoError(t, err) + } + + for _, f := range tc.fakeListeners { + if !f.closed.Load() { + t.Errorf("Expeted sub-listener to be closed") + } + } + }) + } +} + +func TestMultiListen_Accept(t *testing.T) { + testCases := []struct { + name string + addrs []string + runner func(listener net.Listener, acceptCalls int) (map[string]any, error) + fakeListeners []*fakeListener + acceptCalls int + expectedCons map[string]any + errString string + }{ + { + name: "accept connections", + addrs: []string{"10.10.10.10:3000", "192.168.1.103:4000", "127.0.0.1:5000"}, + runner: func(ml net.Listener, acceptCalls int) (map[string]any, error) { + acceptedCons := make(map[string]any) + for i := 0; i < acceptCalls; i++ { + conn, err := ml.Accept() + if err != nil { + return nil, err + } + acceptedCons[conn.RemoteAddr().String()] = struct{}{} + } + err := ml.Close() + if err != nil { + return nil, err + } + return acceptedCons, nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("10.10.10.10"), Port: 50001}}, + err: nil, + }}}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("192.168.1.10"), Port: 50002}}, + err: nil, + }, + }}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 50003}}, + err: nil, + }}, + }}, + acceptCalls: 3, + expectedCons: map[string]any{ + "10.10.10.10:50001": struct{}{}, + "192.168.1.10:50002": struct{}{}, + "127.0.0.1:50003": struct{}{}, + }}, + { + name: "accept on closed listener", + addrs: []string{"10.10.10.10:3001", "192.168.1.10:3002", "127.0.0.1:3003"}, + runner: func(ml net.Listener, acceptCalls int) (map[string]any, error) { + acceptedCons := make(map[string]any) + err := ml.Close() + if err != nil { + return nil, err + } + for i := 0; i < acceptCalls; i++ { + conn, err := ml.Accept() + if err != nil { + return nil, err + } + acceptedCons[conn.RemoteAddr().String()] = struct{}{} + } + return acceptedCons, nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("10.10.10.10"), Port: 50001}}, + }}}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("192.168.1.10"), Port: 50002}}, + }, + }}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 50003}}, + }}, + }}, + acceptCalls: 1, + errString: "use of closed network connection", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, "tcp", tc.addrs, listenFuncFactory(tc.fakeListeners)) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + acceptedCons, err := tc.runner(ml, tc.acceptCalls) + + if tc.errString != "" { + assertError(t, tc.errString, err) + } else { + assertNoError(t, err) + } + + if !reflect.DeepEqual(acceptedCons, tc.expectedCons) { + // golang treats empty map as not equal + // (ref: https://github.com/golang/go/issues/16531) + if !(len(acceptedCons) == 0 && len(tc.expectedCons) == 0) { + t.Errorf("Expected %v; got %v", tc.expectedCons, acceptedCons) + } + } + }) + } +} + +func assertError(t *testing.T, errString string, err error) { + if err == nil && errString != "" { + t.Errorf("Expected error '%s' but got none", errString) + } + if err.Error() != errString { + t.Errorf("Expected error '%s' but got '%s'", errString, err.Error()) + } +} + +func assertNoError(t *testing.T, err error) { + if err != nil { + t.Errorf("Did not expect error: %v", err) + } +}