Skip to content

Commit

Permalink
feat(wit-bindgen-go): align flags decoding
Browse files Browse the repository at this point in the history
Flags encoding/decoding in Go is now rewritten
to use the correct binary encoding.

Flag structs now contain an addition ReadFromIndex
method that can be used for decoding. This makes
the generated code more testable.
  • Loading branch information
tchap authored and rvolosatovs committed Jun 28, 2024
1 parent 02c4350 commit 4fd3c34
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 55 deletions.
130 changes: 78 additions & 52 deletions crates/wit-bindgen-go/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use wit_bindgen_core::wit_parser::{
Variant, World, WorldItem, WorldKey,
};
use wit_bindgen_core::{uwrite, uwriteln, Source, TypeInfo};
use wrpc_introspect::{async_paths_ty, flag_repr, is_list_of, is_ty, rpc_func_name};
use wrpc_introspect::{async_paths_ty, is_list_of, is_ty, rpc_func_name};

use crate::{
to_go_ident, to_package_ident, to_upper_camel_case, Deps, GoWrpc, Identifier, InterfaceName,
Expand Down Expand Up @@ -737,37 +737,20 @@ impl InterfaceGenerator<'_> {
self.src.push_str(")");
}

fn print_read_flags(&mut self, ty: &Flags, reader: &str, name: &str) {
let fmt = self.deps.fmt();
let io = self.deps.io();

let repr = flag_repr(ty);
fn print_read_flags(&mut self, reader: &str, name: &str) {
let wrpc = self.deps.wrpc();

uwrite!(
self.src,
r#"func(r {io}.ByteReader) (*{name}, error) {{
v := &{name}{{}}
n, err := "#
r#"func(r {wrpc}.IndexReader) (*{name}, error) {{
v := {name}{{}}
if err := v.ReadFromIndex(r); err != nil {{
return nil, err
}}
return &v, nil
}}({reader})
"#
);
self.print_read_discriminant(repr, "r");
self.push_str("\n");
self.push_str("if err != nil {\n");
self.push_str("return nil, ");
self.push_str(fmt);
self.push_str(".Errorf(\"failed to read flag: %w\", err)\n");
self.push_str("}\n");
for (i, Flag { name, .. }) in ty.flags.iter().enumerate() {
if i > 64 {
break;
}
uwriteln!(self.src, "if n & (1 << {i}) > 0 {{");
self.push_str("v.");
self.push_str(&name.to_upper_camel_case());
self.push_str(" = true\n");
self.push_str("}\n");
}
self.push_str("return v, nil\n");
uwrite!(self.src, "}}({reader})");
}

fn print_read_enum(&mut self, ty: &Enum, reader: &str, name: &str) {
Expand Down Expand Up @@ -1328,8 +1311,8 @@ impl InterfaceGenerator<'_> {
TypeDefKind::Resource => self.print_read_string(reader),
TypeDefKind::Handle(Handle::Own(id)) => self.print_read_own(reader, *id),
TypeDefKind::Handle(Handle::Borrow(id)) => self.print_read_borrow(reader, *id),
TypeDefKind::Flags(ty) => {
self.print_read_flags(ty, reader, &name.expect("flag missing a name"));
TypeDefKind::Flags(_ty) => {
self.print_read_flags(reader, &name.expect("flag missing a name"));
}
TypeDefKind::Tuple(ty) => self.print_read_tuple(ty, reader, path),
TypeDefKind::Variant(ty) => {
Expand Down Expand Up @@ -3743,13 +3726,13 @@ func (v *{name}) WriteToIndex(w {wrpc}.ByteWriter) (func({wrpc}.IndexWriter) err
}

fn type_flags(&mut self, id: TypeId, _name: &str, ty: &Flags, docs: &Docs) {
let repr = flag_repr(ty);

let info = self.info(id);
if let Some(name) = self.name_of(id) {
let strings = self.deps.strings();
let wrpc = self.deps.wrpc();
let errors = self.deps.errors();

// Struct
self.godoc(docs);
uwriteln!(self.src, "type {name} struct {{");
for Flag { name, docs } in &ty.flags {
Expand All @@ -3759,6 +3742,7 @@ func (v *{name}) WriteToIndex(w {wrpc}.ByteWriter) (func({wrpc}.IndexWriter) err
}
self.push_str("}\n");

// String()
uwriteln!(self.src, "func (v *{name}) String() string {{");
uwriteln!(self.src, "flags := make([]string, 0, {})", ty.flags.len());
for Flag { name, .. } in &ty.flags {
Expand All @@ -3769,34 +3753,76 @@ func (v *{name}) WriteToIndex(w {wrpc}.ByteWriter) (func({wrpc}.IndexWriter) err
self.push_str("}\n");
}
uwriteln!(self.src, r#"return {strings}.Join(flags, " | ")"#);
self.push_str("}\n");
self.push_str("}\n\n");

// WriteToIndex()
let mut buf_len = ty.flags.len() / 8;
if ty.flags.len() % 8 > 0 {
buf_len += 1;
}

uwriteln!(
self.src,
"func (v *{name}) WriteToIndex(w {wrpc}.ByteWriter) (func({wrpc}.IndexWriter) error, error) {{"
r#"func (v *{name}) WriteToIndex(w {wrpc}.ByteWriter) (func({wrpc}.IndexWriter) error, error) {{
var p [{buf_len}]byte
"#
);
self.push_str("var n ");
self.int_repr(repr);
self.push_str("\n");

for (i, Flag { name, .. }) in ty.flags.iter().enumerate() {
self.push_str("if v.");
self.push_str(&name.to_upper_camel_case());
self.push_str(" {\n");
if i <= 64 {
uwriteln!(self.src, "n |= 1 << {i}");
} else {
let errors = self.deps.errors();
uwriteln!(
self.src,
r#"return nil, {errors}.New("encoding `{name}` flag value would overflow 64-bit integer, flags containing more than 64 members are not supported yet")"#
);
}
self.push_str("}\n");
uwriteln!(
self.src,
r#"{{
p[{}] |= 1 << {}
}}"#,
i / 8,
i % 8
);
}
self.push_str("return nil, ");
self.print_write_discriminant(repr, "n", "w");
self.push_str("\n");
self.push_str("}\n");
uwriteln!(
self.src,
r#"
_, err := w.Write(p[:])
return nil, err
}}
"#,
);

// ReadFromIndex()
uwrite!(
self.src,
r#"func (v *{name}) ReadFromIndex(r {wrpc}.IndexReader) error {{
var p [{buf_len}]byte
if _, err := r.Read(p[:]); err != nil {{
return err
}}
"#
);

for (i, Flag { name, .. }) in ty.flags.iter().enumerate() {
uwriteln!(
self.src,
"v.{} = p[{}] & (1 << {}) > 0",
name.to_upper_camel_case(),
i / 8,
i % 8
);
}

uwriteln!(
self.src,
r#"
if (p[{}] >> {}) > 0 {{
return {errors}.New("bit not associated with any flag is set")
}}"#,
buf_len - 1,
ty.flags.len() % 8,
);
self.push_str("return nil\n}\n");

// Error()
if info.error {
uwriteln!(
self.src,
Expand Down
3 changes: 3 additions & 0 deletions go.work.sum
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
Expand Down
6 changes: 5 additions & 1 deletion tests/go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@ go 1.22.2
require (
github.com/google/uuid v1.6.0
github.com/nats-io/nats-server/v2 v2.10.14
github.com/nats-io/nats.go v1.35.0
github.com/nats-io/nats.go v1.36.0
github.com/wrpc/wrpc/go v0.0.0-unpublished
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/klauspost/compress v1.17.8 // indirect
github.com/minio/highwayhash v1.0.2 // indirect
github.com/nats-io/jwt/v2 v2.5.5 // indirect
github.com/nats-io/nkeys v0.4.7 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/time v0.5.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

replace github.com/wrpc/wrpc/go v0.0.0-unpublished => ../../go
14 changes: 12 additions & 2 deletions tests/go/go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
Expand All @@ -8,16 +10,24 @@ github.com/nats-io/jwt/v2 v2.5.5 h1:ROfXb50elFq5c9+1ztaUbdlrArNFl2+fQWP6B8HGEq4=
github.com/nats-io/jwt/v2 v2.5.5/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A=
github.com/nats-io/nats-server/v2 v2.10.14 h1:98gPJFOAO2vLdM0gogh8GAiHghwErrSLhugIqzRC+tk=
github.com/nats-io/nats-server/v2 v2.10.14/go.mod h1:a0TwOVBJZz6Hwv7JH2E4ONdpyFk9do0C18TEwxnHdRk=
github.com/nats-io/nats.go v1.35.0 h1:XFNqNM7v5B+MQMKqVGAyHwYhyKb48jrenXNxIU20ULk=
github.com/nats-io/nats.go v1.35.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
github.com/nats-io/nats.go v1.36.0 h1:suEUPuWzTSse/XhESwqLxXGuj8vGRuPRoG7MoRN/qyU=
github.com/nats-io/nats.go v1.36.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI=
github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
63 changes: 63 additions & 0 deletions tests/go/types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//go:generate $WIT_BINDGEN_WRPC go --gofmt=false --world types --out-dir bindings/types --package github.com/wrpc/wrpc/tests/go/bindings/types ../wit

package integration_test

import (
"bytes"
"testing"

"github.com/stretchr/testify/assert"

wrpc "github.com/wrpc/wrpc/go"
"github.com/wrpc/wrpc/tests/go/bindings/types/wrpc_test/integration/get_types"
)

type indexReader struct {
*bytes.Buffer
}

func (r *indexReader) Index(path ...uint32) (wrpc.IndexReader, error) {
panic("not implemented")
}

func TestTypes_Flags(t *testing.T) {
t.Run("WriteToIndex", func(t *testing.T) {
check := assert.New(t)

flags := get_types.FeatureFlags{
ShowA: true,
ShowC: true,
}
var b bytes.Buffer
_, err := flags.WriteToIndex(&b)
check.NoError(err)
check.Equal([]byte{0b00000101}, b.Bytes())
})

t.Run("ReadFromIndex", func(t *testing.T) {
check := assert.New(t)

var flags get_types.FeatureFlags
err := flags.ReadFromIndex(&indexReader{
Buffer: bytes.NewBuffer([]byte{0b00000101}),
})
check.NoError(err)
check.Equal(get_types.FeatureFlags{
ShowA: true,
ShowC: true,
}, flags)

t.Run("invalid bit set", func(t *testing.T) {
check := assert.New(t)

var flags get_types.FeatureFlags
err := flags.ReadFromIndex(&indexReader{
Buffer: bytes.NewBuffer([]byte{0b10000000}),
})
if check.Error(err) {
check.Equal("bit not associated with any flag is set", err.Error())
}
check.Zero(flags)
})
})
}
17 changes: 17 additions & 0 deletions tests/wit/test.wit
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,20 @@ world resources-client {
bar: func(v: borrow<foo>) -> u64;
}
}

interface get-types {
flags feature-flags {
show-a,
show-b,
show-c,
show-d,
show-e,
show-f,
}

get-features: func() -> feature-flags;
}

world types {
import get-types;
}

0 comments on commit 4fd3c34

Please sign in to comment.