diff --git a/cmd/raiju/raiju.go b/cmd/raiju/raiju.go index 73fd9b9..9336de7 100644 --- a/cmd/raiju/raiju.go +++ b/cmd/raiju/raiju.go @@ -19,7 +19,6 @@ import ( "github.com/nyonson/raiju" "github.com/nyonson/raiju/lightning" - "github.com/nyonson/raiju/lnd" "github.com/nyonson/raiju/view" ) @@ -118,7 +117,7 @@ func main() { return err } - c := lnd.New(services.Client, services.Client, services.Router, *network) + c := lightning.NewLndClient(services.Client, services.Client, services.Router, *network) f, err := parseFees(*liquidityThresholds, *liquidityFees, *liquidityStickiness) if err != nil { return err @@ -185,7 +184,7 @@ func main() { return err } - c := lnd.New(services.Client, services.Client, services.Router, *network) + c := lightning.NewLndClient(services.Client, services.Client, services.Router, *network) f, err := parseFees(*liquidityThresholds, *liquidityFees, *liquidityStickiness) if err != nil { return err @@ -261,7 +260,7 @@ func main() { return err } - c := lnd.New(services.Client, services.Client, services.Router, *network) + c := lightning.NewLndClient(services.Client, services.Client, services.Router, *network) f, err := parseFees(*liquidityThresholds, *liquidityFees, *liquidityStickiness) if err != nil { return err @@ -325,7 +324,7 @@ func main() { return err } - c := lnd.New(services.Client, services.Client, services.Router, *network) + c := lightning.NewLndClient(services.Client, services.Client, services.Router, *network) f, err := parseFees(*liquidityThresholds, *liquidityFees, *liquidityStickiness) if err != nil { return err @@ -368,7 +367,7 @@ func main() { return err } - c := lnd.New(services.Client, services.Client, services.Router, *network) + c := lightning.NewLndClient(services.Client, services.Client, services.Router, *network) f, err := parseFees(*liquidityThresholds, *liquidityFees, *liquidityStickiness) if err != nil { return err diff --git a/lightning/lightning.go b/lightning/lightning.go index 0ed6ce9..b33eaca 100644 --- a/lightning/lightning.go +++ b/lightning/lightning.go @@ -6,7 +6,7 @@ import ( "time" ) -//go:generate gotests -w -exported . +//go:generate gotests -w -exported lightning.go // Satoshi unit of bitcoin. type Satoshi int64 diff --git a/lnd/lnd.go b/lightning/lnd.go similarity index 65% rename from lnd/lnd.go rename to lightning/lnd.go index fd5d8c6..b4393c6 100644 --- a/lnd/lnd.go +++ b/lightning/lnd.go @@ -1,5 +1,4 @@ -// Hook raiju up to LND. -package lnd +package lightning import ( "context" @@ -21,10 +20,9 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/zpay32" - "github.com/nyonson/raiju/lightning" ) -//go:generate gotests -w -exported . +//go:generate gotests -w -exported lnd.go //go:generate moq -stub -skip-ensure -out lnd_mock_test.go . channeler router invoicer // channeler is the minimum channel requirements from LND. @@ -51,9 +49,9 @@ type invoicer interface { AddInvoice(ctx context.Context, in *invoicesrpc.AddInvoiceData) (lntypes.Hash, string, error) } -// New Lightning instance. -func New(c channeler, i invoicer, r router, network string) Lnd { - return Lnd{ +// NewLndClient backed by a single LND lightning node. +func NewLndClient(c channeler, i invoicer, r router, network string) LndClient { + return LndClient{ c: c, i: i, r: r, @@ -61,8 +59,8 @@ func New(c channeler, i invoicer, r router, network string) Lnd { } } -// Lnd client backed by LND node. -type Lnd struct { +// LndClient client backed by LND node. +type LndClient struct { c channeler r router i invoicer @@ -70,33 +68,33 @@ type Lnd struct { } // GetInfo of local node. -func (l Lnd) GetInfo(ctx context.Context) (*lightning.Info, error) { +func (l LndClient) GetInfo(ctx context.Context) (*Info, error) { i, err := l.c.GetInfo(ctx) if err != nil { - return &lightning.Info{}, err + return &Info{}, err } - info := lightning.Info{ - PubKey: lightning.PubKey(hex.EncodeToString(i.IdentityPubkey[:])), + info := Info{ + PubKey: PubKey(hex.EncodeToString(i.IdentityPubkey[:])), } return &info, nil } // DescribeGraph of the Lightning Network. -func (l Lnd) DescribeGraph(ctx context.Context) (*lightning.Graph, error) { +func (l LndClient) DescribeGraph(ctx context.Context) (*Graph, error) { g, err := l.c.DescribeGraph(ctx, false) if err != nil { - return &lightning.Graph{}, err + return &Graph{}, err } // marshall nodes - nodes := make([]lightning.Node, len(g.Nodes)) + nodes := make([]Node, len(g.Nodes)) for i, n := range g.Nodes { - nodes[i] = lightning.Node{ - PubKey: lightning.PubKey(n.PubKey.String()), + nodes[i] = Node{ + PubKey: PubKey(n.PubKey.String()), Alias: n.Alias, Updated: n.LastUpdate, Addresses: n.Addresses, @@ -104,16 +102,16 @@ func (l Lnd) DescribeGraph(ctx context.Context) (*lightning.Graph, error) { } // marshall edges - edges := make([]lightning.Edge, len(g.Edges)) + edges := make([]Edge, len(g.Edges)) for i, e := range g.Edges { - edges[i] = lightning.Edge{ - Capacity: lightning.Satoshi(e.Capacity.ToUnit(btcutil.AmountSatoshi)), - Node1: lightning.PubKey(e.Node1.String()), - Node2: lightning.PubKey(e.Node2.String()), + edges[i] = Edge{ + Capacity: Satoshi(e.Capacity.ToUnit(btcutil.AmountSatoshi)), + Node1: PubKey(e.Node1.String()), + Node2: PubKey(e.Node2.String()), } } - graph := &lightning.Graph{ + graph := &Graph{ Nodes: nodes, Edges: edges, } @@ -122,42 +120,42 @@ func (l Lnd) DescribeGraph(ctx context.Context) (*lightning.Graph, error) { } // GetChannel with ID. -func (l Lnd) GetChannel(ctx context.Context, channelID lightning.ChannelID) (lightning.Channel, error) { +func (l LndClient) GetChannel(ctx context.Context, channelID ChannelID) (Channel, error) { // returns a channel edge which doesn't have liquidity info ce, err := l.c.GetChanInfo(ctx, uint64(channelID)) if err != nil { - return lightning.Channel{}, err + return Channel{}, err } local, err := l.c.GetInfo(ctx) if err != nil { - return lightning.Channel{}, err + return Channel{}, err } // figure out if which node is local and which is remote remotePubkey := ce.Node1 // FeeRateMilliMsat is a weird name - localFee := lightning.FeePPM(ce.Node2Policy.FeeRateMilliMsat) + localFee := FeePPM(ce.Node2Policy.FeeRateMilliMsat) if local.IdentityPubkey == remotePubkey { remotePubkey = ce.Node2 - localFee = lightning.FeePPM(ce.Node1Policy.FeeRateMilliMsat) + localFee = FeePPM(ce.Node1Policy.FeeRateMilliMsat) } remote, err := l.c.GetNodeInfo(ctx, remotePubkey, false) if err != nil { - return lightning.Channel{}, err + return Channel{}, err } - c := lightning.Channel{ - Edge: lightning.Edge{ - Capacity: lightning.Satoshi(ce.Capacity.ToUnit(btcutil.AmountSatoshi)), - Node1: lightning.PubKey(ce.Node1.String()), - Node2: lightning.PubKey(ce.Node2.String()), + c := Channel{ + Edge: Edge{ + Capacity: Satoshi(ce.Capacity.ToUnit(btcutil.AmountSatoshi)), + Node1: PubKey(ce.Node1.String()), + Node2: PubKey(ce.Node2.String()), }, - ChannelID: lightning.ChannelID(ce.ChannelID), + ChannelID: ChannelID(ce.ChannelID), LocalFee: localFee, - RemoteNode: lightning.Node{ - PubKey: lightning.PubKey(remote.PubKey.String()), + RemoteNode: Node{ + PubKey: PubKey(remote.PubKey.String()), Alias: remote.Alias, Updated: remote.LastUpdate, Addresses: remote.Addresses, @@ -167,13 +165,13 @@ func (l Lnd) GetChannel(ctx context.Context, channelID lightning.ChannelID) (lig // get local and remote liquidity from the list channels call cs, err := l.c.ListChannels(ctx, false, false) if err != nil { - return lightning.Channel{}, err + return Channel{}, err } for _, ci := range cs { - if lightning.ChannelID(ci.ChannelID) == channelID { - c.LocalBalance = lightning.Satoshi(ci.LocalBalance.ToUnit(btcutil.AmountSatoshi)) - c.RemoteBalance = lightning.Satoshi(ci.RemoteBalance.ToUnit(btcutil.AmountSatoshi)) + if ChannelID(ci.ChannelID) == channelID { + c.LocalBalance = Satoshi(ci.LocalBalance.ToUnit(btcutil.AmountSatoshi)) + c.RemoteBalance = Satoshi(ci.RemoteBalance.ToUnit(btcutil.AmountSatoshi)) c.Private = ci.Private } } @@ -182,7 +180,7 @@ func (l Lnd) GetChannel(ctx context.Context, channelID lightning.ChannelID) (lig } // ListChannels of local node. -func (l Lnd) ListChannels(ctx context.Context) (lightning.Channels, error) { +func (l LndClient) ListChannels(ctx context.Context) (Channels, error) { channelInfos, err := l.c.ListChannels(ctx, false, false) if err != nil { return nil, err @@ -193,7 +191,7 @@ func (l Lnd) ListChannels(ctx context.Context) (lightning.Channels, error) { return nil, err } - channels := make([]lightning.Channel, len(channelInfos)) + channels := make([]Channel, len(channelInfos)) for i, ci := range channelInfos { ce, err := l.c.GetChanInfo(ctx, ci.ChannelID) if err != nil { @@ -202,10 +200,10 @@ func (l Lnd) ListChannels(ctx context.Context) (lightning.Channels, error) { // figure out if which node is local and which is remote remotePubkey := ce.Node1 - localFee := lightning.FeePPM(ce.Node2Policy.FeeRateMilliMsat) + localFee := FeePPM(ce.Node2Policy.FeeRateMilliMsat) if local.IdentityPubkey == remotePubkey { remotePubkey = ce.Node2 - localFee = lightning.FeePPM(ce.Node1Policy.FeeRateMilliMsat) + localFee = FeePPM(ce.Node1Policy.FeeRateMilliMsat) } remote, err := l.c.GetNodeInfo(ctx, remotePubkey, false) @@ -213,18 +211,18 @@ func (l Lnd) ListChannels(ctx context.Context) (lightning.Channels, error) { return nil, err } - channels[i] = lightning.Channel{ - Edge: lightning.Edge{ - Capacity: lightning.Satoshi(ce.Capacity.ToUnit(btcutil.AmountSatoshi)), - Node1: lightning.PubKey(ce.Node1.String()), - Node2: lightning.PubKey(ce.Node2.String()), + channels[i] = Channel{ + Edge: Edge{ + Capacity: Satoshi(ce.Capacity.ToUnit(btcutil.AmountSatoshi)), + Node1: PubKey(ce.Node1.String()), + Node2: PubKey(ce.Node2.String()), }, - ChannelID: lightning.ChannelID(ci.ChannelID), - LocalBalance: lightning.Satoshi(ci.LocalBalance.ToUnit(btcutil.AmountSatoshi)), + ChannelID: ChannelID(ci.ChannelID), + LocalBalance: Satoshi(ci.LocalBalance.ToUnit(btcutil.AmountSatoshi)), LocalFee: localFee, - RemoteBalance: lightning.Satoshi(ci.RemoteBalance.ToUnit(btcutil.AmountSatoshi)), - RemoteNode: lightning.Node{ - PubKey: lightning.PubKey(remote.PubKey.String()), + RemoteBalance: Satoshi(ci.RemoteBalance.ToUnit(btcutil.AmountSatoshi)), + RemoteNode: Node{ + PubKey: PubKey(remote.PubKey.String()), Alias: remote.Alias, Updated: remote.LastUpdate, Addresses: remote.Addresses, @@ -237,7 +235,7 @@ func (l Lnd) ListChannels(ctx context.Context) (lightning.Channels, error) { } // SetFees for channel with rate in ppm. -func (l Lnd) SetFees(ctx context.Context, channelID lightning.ChannelID, fee lightning.FeePPM) error { +func (l LndClient) SetFees(ctx context.Context, channelID ChannelID, fee FeePPM) error { ce, err := l.c.GetChanInfo(ctx, uint64(channelID)) if err != nil { return err @@ -258,16 +256,16 @@ func (l Lnd) SetFees(ctx context.Context, channelID lightning.ChannelID, fee lig } // AddInvoice of amount. -func (l Lnd) AddInvoice(ctx context.Context, amount lightning.Satoshi) (lightning.Invoice, error) { +func (l LndClient) AddInvoice(ctx context.Context, amount Satoshi) (Invoice, error) { in := &invoicesrpc.AddInvoiceData{ Value: lnwire.NewMSatFromSatoshis(btcutil.Amount(amount)), } _, invoice, err := l.i.AddInvoice(ctx, in) - return lightning.Invoice(invoice), err + return Invoice(invoice), err } // SendPayment to pay for invoice. -func (l Lnd) SendPayment(ctx context.Context, invoice lightning.Invoice, outChannelID lightning.ChannelID, lastHopPubKey lightning.PubKey, maxFee lightning.FeePPM) (lightning.Satoshi, error) { +func (l LndClient) SendPayment(ctx context.Context, invoice Invoice, outChannelID ChannelID, lastHopPubKey PubKey, maxFee FeePPM) (Satoshi, error) { lhpk, err := route.NewVertexFromStr(string(lastHopPubKey)) if err != nil { return 0, err @@ -301,7 +299,7 @@ func (l Lnd) SendPayment(ctx context.Context, invoice lightning.Invoice, outChan select { case s := <-status: if s.State == lnrpc.Payment_SUCCEEDED { - return lightning.Satoshi(s.Fee.ToSatoshis()), nil + return Satoshi(s.Fee.ToSatoshis()), nil } case e := <-error: return 0, fmt.Errorf("error paying invoice: %w", e) @@ -310,8 +308,8 @@ func (l Lnd) SendPayment(ctx context.Context, invoice lightning.Invoice, outChan } // SubscribeChannelUpdates signals when a channel's liquidity changes. -func (l Lnd) SubscribeChannelUpdates(ctx context.Context) (<-chan lightning.Channels, <-chan error, error) { - cc := make(chan lightning.Channels) +func (l LndClient) SubscribeChannelUpdates(ctx context.Context) (<-chan Channels, <-chan error, error) { + cc := make(chan Channels) ec := make(chan error) htlcs, errors, err := l.r.SubscribeHtlcEvents(ctx) @@ -324,10 +322,10 @@ func (l Lnd) SubscribeChannelUpdates(ctx context.Context) (<-chan lightning.Chan for { select { case h := <-htlcs: - channels := make(lightning.Channels, 0) + channels := make(Channels, 0) if h.GetIncomingChannelId() != 0 { - c, err := l.GetChannel(ctx, lightning.ChannelID(h.GetIncomingChannelId())) + c, err := l.GetChannel(ctx, ChannelID(h.GetIncomingChannelId())) if err != nil { ec <- err break @@ -336,7 +334,7 @@ func (l Lnd) SubscribeChannelUpdates(ctx context.Context) (<-chan lightning.Chan } if h.GetOutgoingChannelId() != 0 { - c, err := l.GetChannel(ctx, lightning.ChannelID(h.GetOutgoingChannelId())) + c, err := l.GetChannel(ctx, ChannelID(h.GetOutgoingChannelId())) if err != nil { ec <- err break @@ -355,7 +353,7 @@ func (l Lnd) SubscribeChannelUpdates(ctx context.Context) (<-chan lightning.Chan } // ForwardingHistory of node since the time give, capped at 50,000 events. -func (l Lnd) ForwardingHistory(ctx context.Context, since time.Time) ([]lightning.Forward, error) { +func (l LndClient) ForwardingHistory(ctx context.Context, since time.Time) ([]Forward, error) { maxEvents := 50000 req := lndclient.ForwardingHistoryRequest{ StartTime: since, @@ -372,12 +370,12 @@ func (l Lnd) ForwardingHistory(ctx context.Context, since time.Time) ([]lightnin return nil, errors.New("pulled too many events, lower time window") } - forwards := make([]lightning.Forward, 0) + forwards := make([]Forward, 0) for _, f := range res.Events { - forward := lightning.Forward{ + forward := Forward{ Timestamp: f.Timestamp, - ChannelIn: lightning.ChannelID(f.ChannelIn), - ChannelOut: lightning.ChannelID(f.ChannelOut), + ChannelIn: ChannelID(f.ChannelIn), + ChannelOut: ChannelID(f.ChannelOut), } forwards = append(forwards, forward) } diff --git a/lnd/lnd_mock_test.go b/lightning/lnd_mock_test.go similarity index 99% rename from lnd/lnd_mock_test.go rename to lightning/lnd_mock_test.go index 3dc9682..50bdc77 100644 --- a/lnd/lnd_mock_test.go +++ b/lightning/lnd_mock_test.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package lnd +package lightning import ( "context" diff --git a/lnd/lnd_test.go b/lightning/lnd_test.go similarity index 57% rename from lnd/lnd_test.go rename to lightning/lnd_test.go index f548dd9..d607af8 100644 --- a/lnd/lnd_test.go +++ b/lightning/lnd_test.go @@ -1,5 +1,4 @@ -// Hook raiju up to LND. -package lnd +package lightning import ( "context" @@ -8,47 +7,9 @@ import ( "time" "github.com/lightninglabs/lndclient" - "github.com/nyonson/raiju/lightning" ) -func TestNew(t *testing.T) { - type args struct { - c channeler - i invoicer - r router - network string - } - tests := []struct { - name string - args args - want Lnd - }{ - { - name: "happy init", - args: args{ - c: nil, - i: nil, - r: nil, - network: "mainnet", - }, - want: Lnd{ - c: nil, - r: nil, - i: nil, - network: "mainnet", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := New(tt.args.c, tt.args.i, tt.args.r, tt.args.network); !reflect.DeepEqual(got, tt.want) { - t.Errorf("New() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestLnd_GetInfo(t *testing.T) { +func TestLndClient_GetInfo(t *testing.T) { var pubKey [33]byte type fields struct { @@ -63,7 +24,7 @@ func TestLnd_GetInfo(t *testing.T) { name string fields fields args args - want *lightning.Info + want *Info wantErr bool }{ { @@ -80,7 +41,7 @@ func TestLnd_GetInfo(t *testing.T) { i: &invoicerMock{}, }, args: args{}, - want: &lightning.Info{ + want: &Info{ PubKey: "000000000000000000000000000000000000000000000000000000000000000000", }, wantErr: false, @@ -88,7 +49,7 @@ func TestLnd_GetInfo(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := Lnd{ + l := LndClient{ c: tt.fields.c, r: tt.fields.r, i: tt.fields.i, @@ -105,86 +66,74 @@ func TestLnd_GetInfo(t *testing.T) { } } -func TestLnd_DescribeGraph(t *testing.T) { - type fields struct { - c channeler - r router - i invoicer - } +func TestNewLndClient(t *testing.T) { type args struct { - ctx context.Context + c channeler + i invoicer + r router + network string } tests := []struct { - name string - fields fields - args args - want *lightning.Graph - wantErr bool + name string + args args + want LndClient }{ // TODO: Add test cases. } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := Lnd{ - c: tt.fields.c, - r: tt.fields.r, - i: tt.fields.i, - } - got, err := l.DescribeGraph(tt.args.ctx) - if (err != nil) != tt.wantErr { - t.Errorf("Lnd.DescribeGraph() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Lnd.DescribeGraph() = %v, want %v", got, tt.want) + if got := NewLndClient(tt.args.c, tt.args.i, tt.args.r, tt.args.network); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewLndClient() = %v, want %v", got, tt.want) } }) } } -func TestLnd_GetChannel(t *testing.T) { +func TestLndClient_DescribeGraph(t *testing.T) { type fields struct { - c channeler - r router - i invoicer + c channeler + r router + i invoicer + network string } type args struct { - ctx context.Context - channelID lightning.ChannelID + ctx context.Context } tests := []struct { name string fields fields args args - want lightning.Channel + want *Graph wantErr bool }{ // TODO: Add test cases. } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := Lnd{ - c: tt.fields.c, - r: tt.fields.r, - i: tt.fields.i, + l := LndClient{ + c: tt.fields.c, + r: tt.fields.r, + i: tt.fields.i, + network: tt.fields.network, } - got, err := l.GetChannel(tt.args.ctx, tt.args.channelID) + got, err := l.DescribeGraph(tt.args.ctx) if (err != nil) != tt.wantErr { - t.Errorf("Lnd.GetChannel() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("LndClient.DescribeGraph() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Lnd.GetChannel() = %v, want %v", got, tt.want) + t.Errorf("LndClient.DescribeGraph() = %v, want %v", got, tt.want) } }) } } -func TestLnd_ListChannels(t *testing.T) { +func TestLndClient_ListChannels(t *testing.T) { type fields struct { - c channeler - r router - i invoicer + c channeler + r router + i invoicer + network string } type args struct { ctx context.Context @@ -193,40 +142,42 @@ func TestLnd_ListChannels(t *testing.T) { name string fields fields args args - want lightning.Channels + want Channels wantErr bool }{ // TODO: Add test cases. } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := Lnd{ - c: tt.fields.c, - r: tt.fields.r, - i: tt.fields.i, + l := LndClient{ + c: tt.fields.c, + r: tt.fields.r, + i: tt.fields.i, + network: tt.fields.network, } got, err := l.ListChannels(tt.args.ctx) if (err != nil) != tt.wantErr { - t.Errorf("Lnd.ListChannels() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("LndClient.ListChannels() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Lnd.ListChannels() = %v, want %v", got, tt.want) + t.Errorf("LndClient.ListChannels() = %v, want %v", got, tt.want) } }) } } -func TestLnd_SetFees(t *testing.T) { +func TestLndClient_SetFees(t *testing.T) { type fields struct { - c channeler - r router - i invoicer + c channeler + r router + i invoicer + network string } type args struct { ctx context.Context - channelID lightning.ChannelID - fee lightning.FeePPM + channelID ChannelID + fee FeePPM } tests := []struct { name string @@ -238,99 +189,108 @@ func TestLnd_SetFees(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := Lnd{ - c: tt.fields.c, - r: tt.fields.r, - i: tt.fields.i, + l := LndClient{ + c: tt.fields.c, + r: tt.fields.r, + i: tt.fields.i, + network: tt.fields.network, } if err := l.SetFees(tt.args.ctx, tt.args.channelID, tt.args.fee); (err != nil) != tt.wantErr { - t.Errorf("Lnd.SetFees() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("LndClient.SetFees() error = %v, wantErr %v", err, tt.wantErr) } }) } } -func TestLnd_AddInvoice(t *testing.T) { +func TestLndClient_AddInvoice(t *testing.T) { type fields struct { - c channeler - r router - i invoicer + c channeler + r router + i invoicer + network string } type args struct { ctx context.Context - amount lightning.Satoshi + amount Satoshi } tests := []struct { name string fields fields args args - want lightning.Invoice + want Invoice wantErr bool }{ // TODO: Add test cases. } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := Lnd{ - c: tt.fields.c, - r: tt.fields.r, - i: tt.fields.i, + l := LndClient{ + c: tt.fields.c, + r: tt.fields.r, + i: tt.fields.i, + network: tt.fields.network, } got, err := l.AddInvoice(tt.args.ctx, tt.args.amount) if (err != nil) != tt.wantErr { - t.Errorf("Lnd.AddInvoice() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("LndClient.AddInvoice() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Lnd.AddInvoice() = %v, want %v", got, tt.want) + if got != tt.want { + t.Errorf("LndClient.AddInvoice() = %v, want %v", got, tt.want) } }) } } -func TestLnd_ForwardingHistory(t *testing.T) { +func TestLndClient_SendPayment(t *testing.T) { type fields struct { - c channeler - r router - i invoicer + c channeler + r router + i invoicer + network string } type args struct { - ctx context.Context - since time.Time + ctx context.Context + invoice Invoice + outChannelID ChannelID + lastHopPubKey PubKey + maxFee FeePPM } tests := []struct { name string fields fields args args - want []lightning.Forward + want Satoshi wantErr bool }{ // TODO: Add test cases. } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := Lnd{ - c: tt.fields.c, - r: tt.fields.r, - i: tt.fields.i, + l := LndClient{ + c: tt.fields.c, + r: tt.fields.r, + i: tt.fields.i, + network: tt.fields.network, } - got, err := l.ForwardingHistory(tt.args.ctx, tt.args.since) + got, err := l.SendPayment(tt.args.ctx, tt.args.invoice, tt.args.outChannelID, tt.args.lastHopPubKey, tt.args.maxFee) if (err != nil) != tt.wantErr { - t.Errorf("Lnd.ForwardingHistory() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("LndClient.SendPayment() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Lnd.ForwardingHistory() = %v, want %v", got, tt.want) + if got != tt.want { + t.Errorf("LndClient.SendPayment() = %v, want %v", got, tt.want) } }) } } -func TestLnd_SubscribeChannelUpdates(t *testing.T) { +func TestLndClient_SubscribeChannelUpdates(t *testing.T) { type fields struct { - c channeler - r router - i invoicer + c channeler + r router + i invoicer + network string } type args struct { ctx context.Context @@ -339,7 +299,7 @@ func TestLnd_SubscribeChannelUpdates(t *testing.T) { name string fields fields args args - want <-chan lightning.Channels + want <-chan Channels want1 <-chan error wantErr bool }{ @@ -347,27 +307,28 @@ func TestLnd_SubscribeChannelUpdates(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := Lnd{ - c: tt.fields.c, - r: tt.fields.r, - i: tt.fields.i, + l := LndClient{ + c: tt.fields.c, + r: tt.fields.r, + i: tt.fields.i, + network: tt.fields.network, } got, got1, err := l.SubscribeChannelUpdates(tt.args.ctx) if (err != nil) != tt.wantErr { - t.Errorf("Lnd.SubscribeChannelUpdates() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("LndClient.SubscribeChannelUpdates() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Lnd.SubscribeChannelUpdates() got = %v, want %v", got, tt.want) + t.Errorf("LndClient.SubscribeChannelUpdates() got = %v, want %v", got, tt.want) } if !reflect.DeepEqual(got1, tt.want1) { - t.Errorf("Lnd.SubscribeChannelUpdates() got1 = %v, want %v", got1, tt.want1) + t.Errorf("LndClient.SubscribeChannelUpdates() got1 = %v, want %v", got1, tt.want1) } }) } } -func TestLnd_SendPayment(t *testing.T) { +func TestLndClient_ForwardingHistory(t *testing.T) { type fields struct { c channeler r router @@ -375,36 +336,73 @@ func TestLnd_SendPayment(t *testing.T) { network string } type args struct { - ctx context.Context - invoice lightning.Invoice - outChannelID lightning.ChannelID - lastHopPubKey lightning.PubKey - maxFee lightning.FeePPM + ctx context.Context + since time.Time } tests := []struct { name string fields fields args args - want lightning.Satoshi + want []Forward wantErr bool }{ // TODO: Add test cases. } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := Lnd{ + l := LndClient{ c: tt.fields.c, r: tt.fields.r, i: tt.fields.i, network: tt.fields.network, } - got, err := l.SendPayment(tt.args.ctx, tt.args.invoice, tt.args.outChannelID, tt.args.lastHopPubKey, tt.args.maxFee) + got, err := l.ForwardingHistory(tt.args.ctx, tt.args.since) + if (err != nil) != tt.wantErr { + t.Errorf("LndClient.ForwardingHistory() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("LndClient.ForwardingHistory() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLndClient_GetChannel(t *testing.T) { + type fields struct { + c channeler + r router + i invoicer + network string + } + type args struct { + ctx context.Context + channelID ChannelID + } + tests := []struct { + name string + fields fields + args args + want Channel + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := LndClient{ + c: tt.fields.c, + r: tt.fields.r, + i: tt.fields.i, + network: tt.fields.network, + } + got, err := l.GetChannel(tt.args.ctx, tt.args.channelID) if (err != nil) != tt.wantErr { - t.Errorf("Lnd.SendPayment() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("LndClient.GetChannel() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Lnd.SendPayment() = %v, want %v", got, tt.want) + t.Errorf("LndClient.GetChannel() = %v, want %v", got, tt.want) } }) } diff --git a/raiju.go b/raiju.go index c1fef11..154ae9c 100644 --- a/raiju.go +++ b/raiju.go @@ -11,7 +11,7 @@ import ( "github.com/nyonson/raiju/lightning" ) -//go:generate gotests -w -exported . +//go:generate gotests -w -exported raiju.go //go:generate moq -stub -skip-ensure -out raiju_mock_test.go . lightninger type lightninger interface {