diff --git a/pool/interface.go b/pool/interface.go index f13ed5a..7619119 100644 --- a/pool/interface.go +++ b/pool/interface.go @@ -1,19 +1,80 @@ package pool +import ( + "bytes" + "io" + "sync" +) + type Pool[T any] interface { Get() T Put(T) } +// ByteBuffer is satisfied by [*pool.Buffer] and [*bytes.Buffer]. +// +// Note, we can't include Reset and Grow as our implementations return an error while a [*bytes.Buffer] does not. +type ByteBuffer interface { + Len() int + Cap() int + Bytes() []byte + WriteRune(rune) (int, error) + WriteString(string) (int, error) + + io.WriterTo + io.ByteWriter + io.ReadWriter + + io.ReaderFrom + io.ByteReader + io.RuneReader +} + type WithPutError[T any] interface { Get() T Put(T) error } +func (b BufferFactoryInterfaceCompat) Put(buf *Buffer) { + _ = b.BufferFactory.Put(buf) +} + type BufferFactoryInterfaceCompat struct { BufferFactory } -func (b BufferFactoryInterfaceCompat) Put(buf *Buffer) { - _ = b.BufferFactory.Put(buf) +type BufferFactoryByteBufferCompat struct { + BufferFactory +} + +func (bf BufferFactoryByteBufferCompat) Put(buf ByteBuffer) { + if b, ok := buf.(*Buffer); ok { + err := bf.BufferFactory.Put(b) + if err != nil { + panic(err) + } + return + } + if b, ok := buf.(*bytes.Buffer); ok { + newB := &Buffer{ + o: &sync.Once{}, + Buffer: b, + } + _ = bf.BufferFactory.Put(newB) + return + } + // unfortunately this compatibility shim cannot be used with any other types implementing ByteBuffer + // this is because we can't wrap them in a *Buffer + panic("invalid type, need *pool.Buffer or *bytes.Buffer") +} + +func (bf BufferFactoryByteBufferCompat) Get() ByteBuffer { + b := bf.BufferFactory.Get() + return ByteBuffer(b) } + +var ( + _ ByteBuffer = (*Buffer)(nil) + _ ByteBuffer = (*bytes.Buffer)(nil) + _ Pool[ByteBuffer] = (*BufferFactoryByteBufferCompat)(nil) +) diff --git a/pool/interface_test.go b/pool/interface_test.go index b36c712..44b2951 100644 --- a/pool/interface_test.go +++ b/pool/interface_test.go @@ -6,6 +6,10 @@ import ( "testing" ) +type othaBuffa struct { + ByteBuffer +} + // ensure compatibility with interface func TestInterfaces(t *testing.T) { t.Parallel() @@ -65,4 +69,46 @@ func TestInterfaces(t *testing.T) { } }) + t.Run("BufferFactoryByteBufferCompat", func(t *testing.T) { + t.Parallel() + bf := BufferFactoryByteBufferCompat{NewBufferFactory()} + b := bf.Get() + if _, err := b.WriteString("test"); err != nil { + t.Fatal(err) + } + bf.Put(b) + b = bf.Get() + if b.Len() != 0 { + t.Fatal("buffer not reset") + } + foreign := &bytes.Buffer{} + foreign.WriteString("test") + bf.Put(foreign) + if foreign.Len() != 0 { + t.Fatal("buffer not reset") + } + foreignGot := bf.Get() + if foreignGot.Len() != 0 { + t.Fatal("buffer not reset") + } + bf.Put(foreignGot) + t.Run("must panic after wrapped and put twice", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("panic expected") + } + }() + bf.Put(foreignGot) + }) + t.Run("must panic on invalid type", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("panic expected") + } + }() + bf.Put(&othaBuffa{}) + }) + + }) + }