Skip to content

Commit

Permalink
Expose Dialers inside Zk and Region
Browse files Browse the repository at this point in the history
This change allows for overriding the default dialers used by
the ZooKeeper client and the Region client to connect to their
respective servers. Proxy-aware dialers can be installed by people
who need them.
  • Loading branch information
marcinromaszewicz committed Apr 25, 2024
1 parent 1676ef7 commit c121a7a
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 24 deletions.
3 changes: 2 additions & 1 deletion admin_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

log "github.com/sirupsen/logrus"

"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/pb"
"github.com/tsuna/gohbase/region"
Expand Down Expand Up @@ -67,7 +68,7 @@ func newAdminClient(zkquorum string, options ...Option) AdminClient {
for _, option := range options {
option(c)
}
c.zkClient = zk.NewClient(zkquorum, c.zkTimeout)
c.zkClient = zk.NewClient(zkquorum, c.zkTimeout, c.zkDialer)
return c
}

Expand Down
33 changes: 29 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ import (
"sync"
"time"

gzk "github.com/go-zookeeper/zk"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
"modernc.org/b/v2"

"github.com/tsuna/gohbase/compression"
"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/pb"
"github.com/tsuna/gohbase/region"
"github.com/tsuna/gohbase/zk"
"google.golang.org/protobuf/proto"
"modernc.org/b/v2"
)

const (
Expand Down Expand Up @@ -97,9 +99,15 @@ type client struct {
closeOnce sync.Once

newRegionClientFn func(string, region.ClientType, int, time.Duration,
string, time.Duration, compression.Codec) hrpc.RegionClient
string, time.Duration, compression.Codec, region.Dialer) hrpc.RegionClient

compressionCodec compression.Codec

// zkDialer is passed through to Zk Connect() to configure custom connection settings
zkDialer gzk.Dialer
// regionDialer is passed into the region client to connect to hbase in a custom way,
// such as SOCKS proxy.
regionDialer region.Dialer
}

// NewClient creates a new HBase client.
Expand Down Expand Up @@ -140,7 +148,7 @@ func newClient(zkquorum string, options ...Option) *client {

//Have to create the zkClient after the Options have been set
//since the zkTimeout could be changed as an option
c.zkClient = zk.NewClient(zkquorum, c.zkTimeout)
c.zkClient = zk.NewClient(zkquorum, c.zkTimeout, c.zkDialer)

return c
}
Expand Down Expand Up @@ -268,6 +276,23 @@ func CompressionCodec(codec string) Option {
}
}

// ZooKeeperDialer will return an option to pass the given dialer function
// into the ZooKeeper client Connect() call, which allows for customizing
// network connections.
func ZooKeeperDialer(dialer gzk.Dialer) Option {
return func(c *client) {
c.zkDialer = dialer
}
}

// RegionDialer will return an option that uses the specified Dialer for
// connecting to region servers. This allows for connecting through proxies.
func RegionDialer(dialer region.Dialer) Option {
return func(c *client) {
c.regionDialer = dialer
}
}

// Close closes connections to hbase master and regionservers
func (c *client) Close() {
c.closeOnce.Do(func() {
Expand Down
2 changes: 2 additions & 0 deletions debug_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"

"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/region"
)
Expand All @@ -25,6 +26,7 @@ func TestDebugStateSanity(t *testing.T) {
defaultEffectiveUser,
region.DefaultReadTimeout,
client.compressionCodec,
nil,
)
newClientFn := func() hrpc.RegionClient {
return regClient
Expand Down
5 changes: 3 additions & 2 deletions mockrc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ import (
"sync/atomic"
"time"

"google.golang.org/protobuf/proto"

"github.com/tsuna/gohbase/compression"
"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/pb"
"github.com/tsuna/gohbase/region"
"google.golang.org/protobuf/proto"
)

type testClient struct {
Expand Down Expand Up @@ -177,7 +178,7 @@ func init() {

func newMockRegionClient(addr string, ctype region.ClientType, queueSize int,
flushInterval time.Duration, effectiveUser string,
readTimeout time.Duration, codec compression.Codec) hrpc.RegionClient {
readTimeout time.Duration, codec compression.Codec, dialer region.Dialer) hrpc.RegionClient {
m.Lock()
clients[addr]++
m.Unlock()
Expand Down
8 changes: 6 additions & 2 deletions region/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ import (
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"

"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/pb"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"

"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/pb"
)

// ClientType is a type alias to represent the type of this region client
Expand Down Expand Up @@ -192,6 +193,9 @@ type client struct {

// compressor for cellblocks. if nil, then no compression
compressor *compressor

// dialer is used to connect to region servers in non-standard ways
dialer Dialer
}

// QueueRPC will add an rpc call to the queue for processing by the writer goroutine
Expand Down
20 changes: 17 additions & 3 deletions region/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@ import (
"github.com/tsuna/gohbase/hrpc"
)

// Dialer is used to connect to region servers. net.Dialer conforms to this
// interface, which is just the subset of it that we use.
type Dialer interface {
DialContext(ctx context.Context, net, addr string) (net.Conn, error)
}

// NewClient creates a new RegionClient.
func NewClient(addr string, ctype ClientType, queueSize int, flushInterval time.Duration,
effectiveUser string, readTimeout time.Duration, codec compression.Codec) hrpc.RegionClient {
effectiveUser string, readTimeout time.Duration, codec compression.Codec,
dialer Dialer) hrpc.RegionClient {
c := &client{
addr: addr,
ctype: ctype,
Expand All @@ -36,14 +43,21 @@ func NewClient(addr string, ctype ClientType, queueSize int, flushInterval time.
if codec != nil {
c.compressor = &compressor{Codec: codec}
}

if dialer != nil {
c.dialer = dialer
} else {
var d net.Dialer
c.dialer = &d
}

return c
}

func (c *client) Dial(ctx context.Context) error {
c.dialOnce.Do(func() {
var d net.Dialer
var err error
c.conn, err = d.DialContext(ctx, "tcp", c.addr)
c.conn, err = c.dialer.DialContext(ctx, "tcp", c.addr)
if err != nil {
c.fail(fmt.Errorf("failed to dial RegionServer: %s", err))
return
Expand Down
9 changes: 5 additions & 4 deletions rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ import (
"time"

log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/codes"
"google.golang.org/protobuf/proto"

"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/internal/observability"
"github.com/tsuna/gohbase/region"
"github.com/tsuna/gohbase/zk"
"go.opentelemetry.io/otel/codes"
"google.golang.org/protobuf/proto"
)

// Constants
Expand Down Expand Up @@ -828,11 +829,11 @@ func (c *client) establishRegion(reg hrpc.RegionInfo, addr string) {
// master that we don't add to the cache
// TODO: consider combining this case with the regular regionserver path
client = c.newRegionClientFn(addr, c.clientType, c.rpcQueueSize, c.flushInterval,
c.effectiveUser, c.regionReadTimeout, nil)
c.effectiveUser, c.regionReadTimeout, nil, c.regionDialer)
} else {
client = c.clients.put(addr, reg, func() hrpc.RegionClient {
return c.newRegionClientFn(addr, c.clientType, c.rpcQueueSize, c.flushInterval,
c.effectiveUser, c.regionReadTimeout, c.compressionCodec)
c.effectiveUser, c.regionReadTimeout, c.compressionCodec, c.regionDialer)
})
}

Expand Down
11 changes: 6 additions & 5 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ import (

"github.com/golang/mock/gomock"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/wrapperspb"
"modernc.org/b/v2"

"github.com/tsuna/gohbase/compression"
"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/pb"
Expand All @@ -29,15 +33,12 @@ import (
mockRegion "github.com/tsuna/gohbase/test/mock/region"
mockZk "github.com/tsuna/gohbase/test/mock/zk"
"github.com/tsuna/gohbase/zk"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/wrapperspb"
"modernc.org/b/v2"
)

func newRegionClientFn(addr string) func() hrpc.RegionClient {
return func() hrpc.RegionClient {
return newMockRegionClient(addr, region.RegionClient,
0, 0, "root", region.DefaultReadTimeout, nil)
0, 0, "root", region.DefaultReadTimeout, nil, nil)
}
}

Expand Down Expand Up @@ -301,7 +302,7 @@ func TestEstablishRegionDialFail(t *testing.T) {

newRegionClientFnCallCount := 0
c.newRegionClientFn = func(_ string, _ region.ClientType, _ int, _ time.Duration,
_ string, _ time.Duration, _ compression.Codec) hrpc.RegionClient {
_ string, _ time.Duration, _ compression.Codec, _ region.Dialer) hrpc.RegionClient {
var rc hrpc.RegionClient
if newRegionClientFnCallCount == 0 {
rc = rcFailDial
Expand Down
15 changes: 12 additions & 3 deletions zk/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ import (
log "github.com/sirupsen/logrus"

"github.com/go-zookeeper/zk"
"github.com/tsuna/gohbase/pb"
"google.golang.org/protobuf/proto"

"github.com/tsuna/gohbase/pb"
)

type logger struct{}
Expand Down Expand Up @@ -58,19 +59,27 @@ type Client interface {
type client struct {
zks []string
sessionTimeout time.Duration
dialer zk.Dialer
}

// NewClient establishes connection to zookeeper and returns the client
func NewClient(zkquorum string, st time.Duration) Client {
func NewClient(zkquorum string, st time.Duration, dialer zk.Dialer) Client {
return &client{
zks: strings.Split(zkquorum, ","),
sessionTimeout: st,
dialer: dialer,
}
}

// LocateResource returns address of the server for the specified resource.
func (c *client) LocateResource(resource ResourceName) (string, error) {
conn, _, err := zk.Connect(c.zks, c.sessionTimeout)
var conn *zk.Conn
var err error
if c.dialer != nil {
conn, _, err = zk.Connect(c.zks, c.sessionTimeout, zk.WithDialer(c.dialer))
} else {
conn, _, err = zk.Connect(c.zks, c.sessionTimeout)
}
if err != nil {
return "", fmt.Errorf("error connecting to ZooKeeper at %v: %s", c.zks, err)
}
Expand Down

0 comments on commit c121a7a

Please sign in to comment.