diff --git a/filter/filter.go b/filter/filter.go new file mode 100644 index 0000000..b732370 --- /dev/null +++ b/filter/filter.go @@ -0,0 +1,107 @@ +package filter + +import ( + "errors" + "io" + + "github.com/ulikunitz/xz/lzma" +) + +// ReaderConfig defines the parameters for the xz reader. The +// SingleStream parameter requests the reader to assume that the +// underlying stream contains only a single stream. +type ReaderConfig struct { + DictCap int +} + +// WriterConfig defines the configuration parameter for a writer. +type WriterConfig struct { + Properties *lzma.Properties + DictCap int + BufSize int + + // match algorithm + Matcher lzma.MatchAlgorithm +} + +// Filter represents a filter in the block header. +type Filter interface { + ID() uint64 + UnmarshalBinary(data []byte) error + MarshalBinary() (data []byte, err error) + Reader(r io.Reader, c *ReaderConfig) (fr io.Reader, err error) + WriteCloser(w io.WriteCloser, c *WriterConfig) (fw io.WriteCloser, err error) + // filter must be last filter + last() bool +} + +func NewFilterReader(c *ReaderConfig, r io.Reader, f []Filter) (fr io.Reader, + err error) { + + if err = VerifyFilters(f); err != nil { + return nil, err + } + + fr = r + for i := len(f) - 1; i >= 0; i-- { + fr, err = f[i].Reader(fr, c) + if err != nil { + return nil, err + } + } + return fr, nil +} + +// newFilterWriteCloser converts a filter list into a WriteCloser that +// can be used by a blockWriter. +func NewFilterWriteCloser(filterWriteConfig *WriterConfig, w io.Writer, f []Filter) (fw io.WriteCloser, err error) { + + if err = VerifyFilters(f); err != nil { + return nil, err + } + fw = nopWriteCloser(w) + for i := len(f) - 1; i >= 0; i-- { + fw, err = f[i].WriteCloser(fw, filterWriteConfig) + if err != nil { + return nil, err + } + } + return fw, nil +} + +// VerifyFilters checks the filter list for the length and the right +// sequence of filters. +func VerifyFilters(f []Filter) error { + if len(f) == 0 { + return errors.New("xz: no filters") + } + if len(f) > 4 { + return errors.New("xz: more than four filters") + } + for _, g := range f[:len(f)-1] { + if g.last() { + return errors.New("xz: last filter is not last") + } + } + if !f[len(f)-1].last() { + return errors.New("xz: wrong last filter") + } + return nil +} + +// nopWCloser implements a WriteCloser with a Close method not doing +// anything. +type nopWCloser struct { + io.Writer +} + +// Close returns nil and doesn't do anything else. +func (c nopWCloser) Close() error { + return nil +} + +// nopWriteCloser converts the Writer into a WriteCloser with a Close +// function that does nothing beside returning nil. +func nopWriteCloser(w io.Writer) io.WriteCloser { + return nopWCloser{w} +} diff --git a/lzmafilter.go b/filter/lzmafilter.go similarity index 66% rename from lzmafilter.go rename to filter/lzmafilter.go index 6f4aa2c..e735939 100644 --- a/lzmafilter.go +++ b/filter/lzmafilter.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xz +package filter import ( "errors" @@ -14,37 +14,43 @@ import ( // LZMA filter constants. const ( - lzmaFilterID = 0x21 - lzmaFilterLen = 3 + LZMAFilterID = 0x21 + LZMAFilterLen = 3 ) -// lzmaFilter declares the LZMA2 filter information stored in an xz +func NewLZMAFilter(cap int64) *LZMAFilter { + return &LZMAFilter{dictCap: cap} +} + +// LZMAFilter declares the LZMA2 filter information stored in an xz // block header. 
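A hedged sketch of how the newly exported filter API composes around an ordinary buffer. The 8 MiB dictionary capacity, the sample payload, and the use of a raw LZMA2 stream outside an xz container are illustrative assumptions, not part of this change:

package main

import (
	"bytes"
	"fmt"
	"io"

	"github.com/ulikunitz/xz/filter"
)

func main() {
	// A single LZMA2 filter; VerifyFilters checks that the last (and here
	// only) filter in the chain is a "last" filter.
	filters := []filter.Filter{filter.NewLZMAFilter(8 << 20)}
	if err := filter.VerifyFilters(filters); err != nil {
		panic(err)
	}

	// Compress through the write-closer chain into a buffer.
	var compressed bytes.Buffer
	wc, err := filter.NewFilterWriteCloser(
		&filter.WriterConfig{DictCap: 8 << 20}, &compressed, filters)
	if err != nil {
		panic(err)
	}
	if _, err := wc.Write([]byte("hello, filter chain")); err != nil {
		panic(err)
	}
	if err := wc.Close(); err != nil {
		panic(err)
	}

	// Decode the raw LZMA2 data again with the matching reader chain.
	fr, err := filter.NewFilterReader(
		&filter.ReaderConfig{DictCap: 8 << 20}, &compressed, filters)
	if err != nil {
		panic(err)
	}
	decoded, err := io.ReadAll(fr)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(decoded)) // hello, filter chain
}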
-type lzmaFilter struct { +type LZMAFilter struct { dictCap int64 } +func (f LZMAFilter) GetDictCap() int64 { return f.dictCap } + // String returns a representation of the LZMA filter. -func (f lzmaFilter) String() string { +func (f LZMAFilter) String() string { return fmt.Sprintf("LZMA dict cap %#x", f.dictCap) } // id returns the ID for the LZMA2 filter. -func (f lzmaFilter) id() uint64 { return lzmaFilterID } +func (f LZMAFilter) ID() uint64 { return LZMAFilterID } -// MarshalBinary converts the lzmaFilter in its encoded representation. -func (f lzmaFilter) MarshalBinary() (data []byte, err error) { +// MarshalBinary converts the LZMAFilter in its encoded representation. +func (f LZMAFilter) MarshalBinary() (data []byte, err error) { c := lzma.EncodeDictCap(f.dictCap) - return []byte{lzmaFilterID, 1, c}, nil + return []byte{LZMAFilterID, 1, c}, nil } // UnmarshalBinary unmarshals the given data representation of the LZMA2 // filter. -func (f *lzmaFilter) UnmarshalBinary(data []byte) error { - if len(data) != lzmaFilterLen { +func (f *LZMAFilter) UnmarshalBinary(data []byte) error { + if len(data) != LZMAFilterLen { return errors.New("xz: data for LZMA2 filter has wrong length") } - if data[0] != lzmaFilterID { + if data[0] != LZMAFilterID { return errors.New("xz: wrong LZMA2 filter id") } if data[1] != 1 { @@ -59,8 +65,8 @@ func (f *lzmaFilter) UnmarshalBinary(data []byte) error { return nil } -// reader creates a new reader for the LZMA2 filter. -func (f lzmaFilter) reader(r io.Reader, c *ReaderConfig) (fr io.Reader, +// Reader creates a new reader for the LZMA2 filter. +func (f LZMAFilter) Reader(r io.Reader, c *ReaderConfig) (fr io.Reader, err error) { config := new(lzma.Reader2Config) @@ -83,8 +89,8 @@ func (f lzmaFilter) reader(r io.Reader, c *ReaderConfig) (fr io.Reader, return fr, nil } -// writeCloser creates a io.WriteCloser for the LZMA2 filter. -func (f lzmaFilter) writeCloser(w io.WriteCloser, c *WriterConfig, +// WriteCloser creates a io.WriteCloser for the LZMA2 filter. +func (f LZMAFilter) WriteCloser(w io.WriteCloser, c *WriterConfig, ) (fw io.WriteCloser, err error) { config := new(lzma.Writer2Config) if c != nil { @@ -114,4 +120,4 @@ func (f lzmaFilter) writeCloser(w io.WriteCloser, c *WriterConfig, // last returns true, because an LZMA2 filter must be the last filter in // the filter list. -func (f lzmaFilter) last() bool { return true } +func (f LZMAFilter) last() bool { return true } diff --git a/format.go b/format.go index edfec9a..4356921 100644 --- a/format.go +++ b/format.go @@ -1,732 +1,24 @@ -// Copyright 2014-2019 Ulrich Kunitz. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - package xz import ( - "bytes" - "crypto/sha256" - "errors" - "fmt" - "hash" - "hash/crc32" - "io" - - "github.com/ulikunitz/xz/lzma" + "github.com/ulikunitz/xz/xzinternals" ) -// allZeros checks whether a given byte slice has only zeros. -func allZeros(p []byte) bool { - for _, c := range p { - if c != 0 { - return false - } - } - return true -} - -// padLen returns the length of the padding required for the given -// argument. -func padLen(n int64) int { - k := int(n % 4) - if k > 0 { - k = 4 - k - } - return k -} - -/*** Header ***/ - -// headerMagic stores the magic bytes for the header -var headerMagic = []byte{0xfd, '7', 'z', 'X', 'Z', 0x00} - // HeaderLen provides the length of the xz file header. 
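With LZMAFilter and its marshalling methods exported, the three-byte block-header encoding of the filter can be round-tripped directly. A minimal sketch; the 4096-byte dictionary capacity is an arbitrary, exactly representable value:

package main

import (
	"fmt"

	"github.com/ulikunitz/xz/filter"
)

func main() {
	f := filter.NewLZMAFilter(4096)
	// data is the block-header entry: ID 0x21, property length 1, encoded dict cap.
	data, err := f.MarshalBinary()
	if err != nil {
		panic(err)
	}

	var g filter.LZMAFilter
	if err := g.UnmarshalBinary(data); err != nil {
		panic(err)
	}
	fmt.Println(g.GetDictCap() == f.GetDictCap()) // true
}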
-const HeaderLen = 12 +const HeaderLen = xzinternals.HeaderLen // Constants for the checksum methods supported by xz. const ( - None byte = 0x0 - CRC32 = 0x1 - CRC64 = 0x4 - SHA256 = 0xa + None = xzinternals.None + CRC32 = xzinternals.CRC32 + CRC64 = xzinternals.CRC64 + SHA256 = xzinternals.SHA256 ) -// errInvalidFlags indicates that flags are invalid. -var errInvalidFlags = errors.New("xz: invalid flags") - -// verifyFlags returns the error errInvalidFlags if the value is -// invalid. -func verifyFlags(flags byte) error { - switch flags { - case None, CRC32, CRC64, SHA256: - return nil - default: - return errInvalidFlags - } -} - -// flagstrings maps flag values to strings. -var flagstrings = map[byte]string{ - None: "None", - CRC32: "CRC-32", - CRC64: "CRC-64", - SHA256: "SHA-256", -} - -// flagString returns the string representation for the given flags. -func flagString(flags byte) string { - s, ok := flagstrings[flags] - if !ok { - return "invalid" - } - return s -} - -// newHashFunc returns a function that creates hash instances for the -// hash method encoded in flags. -func newHashFunc(flags byte) (newHash func() hash.Hash, err error) { - switch flags { - case None: - newHash = newNoneHash - case CRC32: - newHash = newCRC32 - case CRC64: - newHash = newCRC64 - case SHA256: - newHash = sha256.New - default: - err = errInvalidFlags - } - return -} - -// header provides the actual content of the xz file header: the flags. -type header struct { - flags byte -} - -// Errors returned by readHeader. -var errHeaderMagic = errors.New("xz: invalid header magic bytes") - // ValidHeader checks whether data is a correct xz file header. The // length of data must be HeaderLen. func ValidHeader(data []byte) bool { - var h header + var h xzinternals.Header err := h.UnmarshalBinary(data) return err == nil } - -// String returns a string representation of the flags. -func (h header) String() string { - return flagString(h.flags) -} - -// UnmarshalBinary reads header from the provided data slice. -func (h *header) UnmarshalBinary(data []byte) error { - // header length - if len(data) != HeaderLen { - return errors.New("xz: wrong file header length") - } - - // magic header - if !bytes.Equal(headerMagic, data[:6]) { - return errHeaderMagic - } - - // checksum - crc := crc32.NewIEEE() - crc.Write(data[6:8]) - if uint32LE(data[8:]) != crc.Sum32() { - return errors.New("xz: invalid checksum for file header") - } - - // stream flags - if data[6] != 0 { - return errInvalidFlags - } - flags := data[7] - if err := verifyFlags(flags); err != nil { - return err - } - - h.flags = flags - return nil -} - -// MarshalBinary generates the xz file header. -func (h *header) MarshalBinary() (data []byte, err error) { - if err = verifyFlags(h.flags); err != nil { - return nil, err - } - - data = make([]byte, 12) - copy(data, headerMagic) - data[7] = h.flags - - crc := crc32.NewIEEE() - crc.Write(data[6:8]) - putUint32LE(data[8:], crc.Sum32()) - - return data, nil -} - -/*** Footer ***/ - -// footerLen defines the length of the footer. -const footerLen = 12 - -// footerMagic contains the footer magic bytes. -var footerMagic = []byte{'Y', 'Z'} - -// footer represents the content of the xz file footer. -type footer struct { - indexSize int64 - flags byte -} - -// String prints a string representation of the footer structure. -func (f footer) String() string { - return fmt.Sprintf("%s index size %d", flagString(f.flags), f.indexSize) -} - -// Minimum and maximum for the size of the index (backward size). 
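The public surface of format.go keeps ValidHeader and HeaderLen, so callers can still probe the first 12 bytes of a stream without touching xzinternals. A short sketch; the input file name is hypothetical:

package main

import (
	"fmt"
	"io"
	"os"

	"github.com/ulikunitz/xz"
)

func main() {
	f, err := os.Open("example.xz") // hypothetical input file
	if err != nil {
		panic(err)
	}
	defer f.Close()

	hdr := make([]byte, xz.HeaderLen)
	if _, err := io.ReadFull(f, hdr); err != nil {
		panic(err)
	}
	fmt.Println("valid xz header:", xz.ValidHeader(hdr))
}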
-const ( - minIndexSize = 4 - maxIndexSize = (1 << 32) * 4 -) - -// MarshalBinary converts footer values into an xz file footer. Note -// that the footer value is checked for correctness. -func (f *footer) MarshalBinary() (data []byte, err error) { - if err = verifyFlags(f.flags); err != nil { - return nil, err - } - if !(minIndexSize <= f.indexSize && f.indexSize <= maxIndexSize) { - return nil, errors.New("xz: index size out of range") - } - if f.indexSize%4 != 0 { - return nil, errors.New( - "xz: index size not aligned to four bytes") - } - - data = make([]byte, footerLen) - - // backward size (index size) - s := (f.indexSize / 4) - 1 - putUint32LE(data[4:], uint32(s)) - // flags - data[9] = f.flags - // footer magic - copy(data[10:], footerMagic) - - // CRC-32 - crc := crc32.NewIEEE() - crc.Write(data[4:10]) - putUint32LE(data, crc.Sum32()) - - return data, nil -} - -// UnmarshalBinary sets the footer value by unmarshalling an xz file -// footer. -func (f *footer) UnmarshalBinary(data []byte) error { - if len(data) != footerLen { - return errors.New("xz: wrong footer length") - } - - // magic bytes - if !bytes.Equal(data[10:], footerMagic) { - return errors.New("xz: footer magic invalid") - } - - // CRC-32 - crc := crc32.NewIEEE() - crc.Write(data[4:10]) - if uint32LE(data) != crc.Sum32() { - return errors.New("xz: footer checksum error") - } - - var g footer - // backward size (index size) - g.indexSize = (int64(uint32LE(data[4:])) + 1) * 4 - - // flags - if data[8] != 0 { - return errInvalidFlags - } - g.flags = data[9] - if err := verifyFlags(g.flags); err != nil { - return err - } - - *f = g - return nil -} - -/*** Block Header ***/ - -// blockHeader represents the content of an xz block header. -type blockHeader struct { - compressedSize int64 - uncompressedSize int64 - filters []filter -} - -// String converts the block header into a string. -func (h blockHeader) String() string { - var buf bytes.Buffer - first := true - if h.compressedSize >= 0 { - fmt.Fprintf(&buf, "compressed size %d", h.compressedSize) - first = false - } - if h.uncompressedSize >= 0 { - if !first { - buf.WriteString(" ") - } - fmt.Fprintf(&buf, "uncompressed size %d", h.uncompressedSize) - first = false - } - for _, f := range h.filters { - if !first { - buf.WriteString(" ") - } - fmt.Fprintf(&buf, "filter %s", f) - first = false - } - return buf.String() -} - -// Masks for the block flags. -const ( - filterCountMask = 0x03 - compressedSizePresent = 0x40 - uncompressedSizePresent = 0x80 - reservedBlockFlags = 0x3C -) - -// errIndexIndicator signals that an index indicator (0x00) has been found -// instead of an expected block header indicator. -var errIndexIndicator = errors.New("xz: found index indicator") - -// readBlockHeader reads the block header. -func readBlockHeader(r io.Reader) (h *blockHeader, n int, err error) { - var buf bytes.Buffer - buf.Grow(20) - - // block header size - z, err := io.CopyN(&buf, r, 1) - n = int(z) - if err != nil { - return nil, n, err - } - s := buf.Bytes()[0] - if s == 0 { - return nil, n, errIndexIndicator - } - - // read complete header - headerLen := (int(s) + 1) * 4 - buf.Grow(headerLen - 1) - z, err = io.CopyN(&buf, r, int64(headerLen-1)) - n += int(z) - if err != nil { - return nil, n, err - } - - // unmarshal block header - h = new(blockHeader) - if err = h.UnmarshalBinary(buf.Bytes()); err != nil { - return nil, n, err - } - - return h, n, nil -} - -// readSizeInBlockHeader reads the uncompressed or compressed size -// fields in the block header. 
The present value informs the function -// whether the respective field is actually present in the header. -func readSizeInBlockHeader(r io.ByteReader, present bool) (n int64, err error) { - if !present { - return -1, nil - } - x, _, err := readUvarint(r) - if err != nil { - return 0, err - } - if x >= 1<<63 { - return 0, errors.New("xz: size overflow in block header") - } - return int64(x), nil -} - -// UnmarshalBinary unmarshals the block header. -func (h *blockHeader) UnmarshalBinary(data []byte) error { - // Check header length - s := data[0] - if data[0] == 0 { - return errIndexIndicator - } - headerLen := (int(s) + 1) * 4 - if len(data) != headerLen { - return fmt.Errorf("xz: data length %d; want %d", len(data), - headerLen) - } - n := headerLen - 4 - - // Check CRC-32 - crc := crc32.NewIEEE() - crc.Write(data[:n]) - if crc.Sum32() != uint32LE(data[n:]) { - return errors.New("xz: checksum error for block header") - } - - // Block header flags - flags := data[1] - if flags&reservedBlockFlags != 0 { - return errors.New("xz: reserved block header flags set") - } - - r := bytes.NewReader(data[2:n]) - - // Compressed size - var err error - h.compressedSize, err = readSizeInBlockHeader( - r, flags&compressedSizePresent != 0) - if err != nil { - return err - } - - // Uncompressed size - h.uncompressedSize, err = readSizeInBlockHeader( - r, flags&uncompressedSizePresent != 0) - if err != nil { - return err - } - - h.filters, err = readFilters(r, int(flags&filterCountMask)+1) - if err != nil { - return err - } - - // Check padding - // Since headerLen is a multiple of 4 we don't need to check - // alignment. - k := r.Len() - // The standard spec says that the padding should have not more - // than 3 bytes. However we found paddings of 4 or 5 in the - // wild. See https://github.com/ulikunitz/xz/pull/11 and - // https://github.com/ulikunitz/xz/issues/15 - // - // The only reasonable approach seems to be to ignore the - // padding size. We still check that all padding bytes are zero. - if !allZeros(data[n-k : n]) { - return errPadding - } - return nil -} - -// MarshalBinary marshals the binary header. 
-func (h *blockHeader) MarshalBinary() (data []byte, err error) { - if !(minFilters <= len(h.filters) && len(h.filters) <= maxFilters) { - return nil, errors.New("xz: filter count wrong") - } - for i, f := range h.filters { - if i < len(h.filters)-1 { - if f.id() == lzmaFilterID { - return nil, errors.New( - "xz: LZMA2 filter is not the last") - } - } else { - // last filter - if f.id() != lzmaFilterID { - return nil, errors.New("xz: " + - "last filter must be the LZMA2 filter") - } - } - } - - var buf bytes.Buffer - // header size must set at the end - buf.WriteByte(0) - - // flags - flags := byte(len(h.filters) - 1) - if h.compressedSize >= 0 { - flags |= compressedSizePresent - } - if h.uncompressedSize >= 0 { - flags |= uncompressedSizePresent - } - buf.WriteByte(flags) - - p := make([]byte, 10) - if h.compressedSize >= 0 { - k := putUvarint(p, uint64(h.compressedSize)) - buf.Write(p[:k]) - } - if h.uncompressedSize >= 0 { - k := putUvarint(p, uint64(h.uncompressedSize)) - buf.Write(p[:k]) - } - - for _, f := range h.filters { - fp, err := f.MarshalBinary() - if err != nil { - return nil, err - } - buf.Write(fp) - } - - // padding - for i := padLen(int64(buf.Len())); i > 0; i-- { - buf.WriteByte(0) - } - - // crc place holder - buf.Write(p[:4]) - - data = buf.Bytes() - if len(data)%4 != 0 { - panic("data length not aligned") - } - s := len(data)/4 - 1 - if !(1 < s && s <= 255) { - panic("wrong block header size") - } - data[0] = byte(s) - - crc := crc32.NewIEEE() - crc.Write(data[:len(data)-4]) - putUint32LE(data[len(data)-4:], crc.Sum32()) - - return data, nil -} - -// Constants used for marshalling and unmarshalling filters in the xz -// block header. -const ( - minFilters = 1 - maxFilters = 4 - minReservedID = 1 << 62 -) - -// filter represents a filter in the block header. -type filter interface { - id() uint64 - UnmarshalBinary(data []byte) error - MarshalBinary() (data []byte, err error) - reader(r io.Reader, c *ReaderConfig) (fr io.Reader, err error) - writeCloser(w io.WriteCloser, c *WriterConfig) (fw io.WriteCloser, err error) - // filter must be last filter - last() bool -} - -// readFilter reads a block filter from the block header. At this point -// in time only the LZMA2 filter is supported. -func readFilter(r io.Reader) (f filter, err error) { - br := lzma.ByteReader(r) - - // index - id, _, err := readUvarint(br) - if err != nil { - return nil, err - } - - var data []byte - switch id { - case lzmaFilterID: - data = make([]byte, lzmaFilterLen) - data[0] = lzmaFilterID - if _, err = io.ReadFull(r, data[1:]); err != nil { - return nil, err - } - f = new(lzmaFilter) - default: - if id >= minReservedID { - return nil, errors.New( - "xz: reserved filter id in block stream header") - } - return nil, errors.New("xz: invalid filter id") - } - if err = f.UnmarshalBinary(data); err != nil { - return nil, err - } - return f, err -} - -// readFilters reads count filters. At this point in time only the count -// 1 is supported. -func readFilters(r io.Reader, count int) (filters []filter, err error) { - if count != 1 { - return nil, errors.New("xz: unsupported filter count") - } - f, err := readFilter(r) - if err != nil { - return nil, err - } - return []filter{f}, err -} - -// writeFilters writes the filters. 
-func writeFilters(w io.Writer, filters []filter) (n int, err error) { - for _, f := range filters { - p, err := f.MarshalBinary() - if err != nil { - return n, err - } - k, err := w.Write(p) - n += k - if err != nil { - return n, err - } - } - return n, nil -} - -/*** Index ***/ - -// record describes a block in the xz file index. -type record struct { - unpaddedSize int64 - uncompressedSize int64 -} - -// readRecord reads an index record. -func readRecord(r io.ByteReader) (rec record, n int, err error) { - u, k, err := readUvarint(r) - n += k - if err != nil { - return rec, n, err - } - rec.unpaddedSize = int64(u) - if rec.unpaddedSize < 0 { - return rec, n, errors.New("xz: unpadded size negative") - } - - u, k, err = readUvarint(r) - n += k - if err != nil { - return rec, n, err - } - rec.uncompressedSize = int64(u) - if rec.uncompressedSize < 0 { - return rec, n, errors.New("xz: uncompressed size negative") - } - - return rec, n, nil -} - -// MarshalBinary converts an index record in its binary encoding. -func (rec *record) MarshalBinary() (data []byte, err error) { - // maximum length of a uvarint is 10 - p := make([]byte, 20) - n := putUvarint(p, uint64(rec.unpaddedSize)) - n += putUvarint(p[n:], uint64(rec.uncompressedSize)) - return p[:n], nil -} - -// writeIndex writes the index, a sequence of records. -func writeIndex(w io.Writer, index []record) (n int64, err error) { - crc := crc32.NewIEEE() - mw := io.MultiWriter(w, crc) - - // index indicator - k, err := mw.Write([]byte{0}) - n += int64(k) - if err != nil { - return n, err - } - - // number of records - p := make([]byte, 10) - k = putUvarint(p, uint64(len(index))) - k, err = mw.Write(p[:k]) - n += int64(k) - if err != nil { - return n, err - } - - // list of records - for _, rec := range index { - p, err := rec.MarshalBinary() - if err != nil { - return n, err - } - k, err = mw.Write(p) - n += int64(k) - if err != nil { - return n, err - } - } - - // index padding - k, err = mw.Write(make([]byte, padLen(int64(n)))) - n += int64(k) - if err != nil { - return n, err - } - - // crc32 checksum - putUint32LE(p, crc.Sum32()) - k, err = w.Write(p[:4]) - n += int64(k) - - return n, err -} - -// readIndexBody reads the index from the reader. It assumes that the -// index indicator has already been read. 
-func readIndexBody(r io.Reader) (records []record, n int64, err error) { - crc := crc32.NewIEEE() - // index indicator - crc.Write([]byte{0}) - - br := lzma.ByteReader(io.TeeReader(r, crc)) - - // number of records - u, k, err := readUvarint(br) - n += int64(k) - if err != nil { - return nil, n, err - } - recLen := int(u) - if recLen < 0 || uint64(recLen) != u { - return nil, n, errors.New("xz: record number overflow") - } - - // list of records - records = make([]record, recLen) - for i := range records { - records[i], k, err = readRecord(br) - n += int64(k) - if err != nil { - return nil, n, err - } - } - - p := make([]byte, padLen(int64(n+1)), 4) - k, err = io.ReadFull(br.(io.Reader), p) - n += int64(k) - if err != nil { - return nil, n, err - } - if !allZeros(p) { - return nil, n, errors.New("xz: non-zero byte in index padding") - } - - // crc32 - s := crc.Sum32() - p = p[:4] - k, err = io.ReadFull(br.(io.Reader), p) - n += int64(k) - if err != nil { - return records, n, err - } - if uint32LE(p) != s { - return nil, n, errors.New("xz: wrong checksum for index") - } - - return records, n, nil -} diff --git a/reader.go b/reader.go index 22cd6d5..a99b559 100644 --- a/reader.go +++ b/reader.go @@ -1,21 +1,14 @@ -// Copyright 2014-2019 Ulrich Kunitz. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - // Package xz supports the compression and decompression of xz files. It // supports version 1.0.4 of the specification without the non-LZMA2 // filters. See http://tukaani.org/xz/xz-file-format-1.0.4.txt package xz import ( - "bytes" "errors" - "fmt" - "hash" "io" - "github.com/ulikunitz/xz/internal/xlog" "github.com/ulikunitz/xz/lzma" + "github.com/ulikunitz/xz/xzinternals" ) // ReaderConfig defines the parameters for the xz reader. The @@ -51,18 +44,7 @@ type Reader struct { ReaderConfig xz io.Reader - sr *streamReader -} - -// streamReader decodes a single xz stream -type streamReader struct { - ReaderConfig - - xz io.Reader - br *blockReader - newHash func() hash.Hash - h header - index []record + sr *xzinternals.StreamReader } // NewReader creates a new xz reader using the default parameters. @@ -83,7 +65,7 @@ func (c ReaderConfig) NewReader(xz io.Reader) (r *Reader, err error) { ReaderConfig: c, xz: xz, } - if r.sr, err = c.newStreamReader(xz); err != nil { + if r.sr, err = xzinternals.NewStreamReader(xz, c.DictCap); err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } @@ -107,8 +89,12 @@ func (r *Reader) Read(p []byte) (n int, err error) { return n, io.EOF } for { - r.sr, err = r.ReaderConfig.newStreamReader(r.xz) - if err != errPadding { + if err = r.ReaderConfig.Verify(); err != nil { + break + } + + r.sr, err = xzinternals.NewStreamReader(r.xz, r.ReaderConfig.DictCap) + if err != xzinternals.ErrPadding { break } } @@ -128,250 +114,3 @@ func (r *Reader) Read(p []byte) (n int, err error) { } return n, nil } - -var errPadding = errors.New("xz: padding (4 zero bytes) encountered") - -// newStreamReader creates a new xz stream reader using the given configuration -// parameters. NewReader reads and checks the header of the xz stream. 
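For callers the reading path is unchanged: NewReader wraps the stream, and Read restarts an xzinternals.StreamReader at stream and padding boundaries. A minimal usage sketch with a hypothetical input file:

package main

import (
	"io"
	"os"

	"github.com/ulikunitz/xz"
)

func main() {
	f, err := os.Open("example.xz") // hypothetical compressed input
	if err != nil {
		panic(err)
	}
	defer f.Close()

	r, err := xz.NewReader(f)
	if err != nil {
		panic(err)
	}
	// Stream the decompressed data to stdout.
	if _, err := io.Copy(os.Stdout, r); err != nil {
		panic(err)
	}
}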
-func (c ReaderConfig) newStreamReader(xz io.Reader) (r *streamReader, err error) { - if err = c.Verify(); err != nil { - return nil, err - } - data := make([]byte, HeaderLen) - if _, err := io.ReadFull(xz, data[:4]); err != nil { - return nil, err - } - if bytes.Equal(data[:4], []byte{0, 0, 0, 0}) { - return nil, errPadding - } - if _, err = io.ReadFull(xz, data[4:]); err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return nil, err - } - r = &streamReader{ - ReaderConfig: c, - xz: xz, - index: make([]record, 0, 4), - } - if err = r.h.UnmarshalBinary(data); err != nil { - return nil, err - } - xlog.Debugf("xz header %s", r.h) - if r.newHash, err = newHashFunc(r.h.flags); err != nil { - return nil, err - } - return r, nil -} - -// errIndex indicates an error with the xz file index. -var errIndex = errors.New("xz: error in xz file index") - -// readTail reads the index body and the xz footer. -func (r *streamReader) readTail() error { - index, n, err := readIndexBody(r.xz) - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return err - } - if len(index) != len(r.index) { - return fmt.Errorf("xz: index length is %d; want %d", - len(index), len(r.index)) - } - for i, rec := range r.index { - if rec != index[i] { - return fmt.Errorf("xz: record %d is %v; want %v", - i, rec, index[i]) - } - } - - p := make([]byte, footerLen) - if _, err = io.ReadFull(r.xz, p); err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return err - } - var f footer - if err = f.UnmarshalBinary(p); err != nil { - return err - } - xlog.Debugf("xz footer %s", f) - if f.flags != r.h.flags { - return errors.New("xz: footer flags incorrect") - } - if f.indexSize != int64(n)+1 { - return errors.New("xz: index size in footer wrong") - } - return nil -} - -// Read reads actual data from the xz stream. -func (r *streamReader) Read(p []byte) (n int, err error) { - for n < len(p) { - if r.br == nil { - bh, hlen, err := readBlockHeader(r.xz) - if err != nil { - if err == errIndexIndicator { - if err = r.readTail(); err != nil { - return n, err - } - return n, io.EOF - } - return n, err - } - xlog.Debugf("block %v", *bh) - r.br, err = r.ReaderConfig.newBlockReader(r.xz, bh, - hlen, r.newHash()) - if err != nil { - return n, err - } - } - k, err := r.br.Read(p[n:]) - n += k - if err != nil { - if err == io.EOF { - r.index = append(r.index, r.br.record()) - r.br = nil - } else { - return n, err - } - } - } - return n, nil -} - -// countingReader is a reader that counts the bytes read. -type countingReader struct { - r io.Reader - n int64 -} - -// Read reads data from the wrapped reader and adds it to the n field. -func (lr *countingReader) Read(p []byte) (n int, err error) { - n, err = lr.r.Read(p) - lr.n += int64(n) - return n, err -} - -// blockReader supports the reading of a block. -type blockReader struct { - lxz countingReader - header *blockHeader - headerLen int - n int64 - hash hash.Hash - r io.Reader - err error -} - -// newBlockReader creates a new block reader. -func (c *ReaderConfig) newBlockReader(xz io.Reader, h *blockHeader, - hlen int, hash hash.Hash) (br *blockReader, err error) { - - br = &blockReader{ - lxz: countingReader{r: xz}, - header: h, - headerLen: hlen, - hash: hash, - } - - fr, err := c.newFilterReader(&br.lxz, h.filters) - if err != nil { - return nil, err - } - if br.hash.Size() != 0 { - br.r = io.TeeReader(fr, br.hash) - } else { - br.r = fr - } - - return br, nil -} - -// uncompressedSize returns the uncompressed size of the block. 
-func (br *blockReader) uncompressedSize() int64 { - return br.n -} - -// compressedSize returns the compressed size of the block. -func (br *blockReader) compressedSize() int64 { - return br.lxz.n -} - -// unpaddedSize computes the unpadded size for the block. -func (br *blockReader) unpaddedSize() int64 { - n := int64(br.headerLen) - n += br.compressedSize() - n += int64(br.hash.Size()) - return n -} - -// record returns the index record for the current block. -func (br *blockReader) record() record { - return record{br.unpaddedSize(), br.uncompressedSize()} -} - -// errBlockSize indicates that the size of the block in the block header -// is wrong. -var errBlockSize = errors.New("xz: wrong uncompressed size for block") - -// Read reads data from the block. -func (br *blockReader) Read(p []byte) (n int, err error) { - n, err = br.r.Read(p) - br.n += int64(n) - - u := br.header.uncompressedSize - if u >= 0 && br.uncompressedSize() > u { - return n, errors.New("xz: wrong uncompressed size for block") - } - c := br.header.compressedSize - if c >= 0 && br.compressedSize() > c { - return n, errors.New("xz: wrong compressed size for block") - } - if err != io.EOF { - return n, err - } - if br.uncompressedSize() < u || br.compressedSize() < c { - return n, io.ErrUnexpectedEOF - } - - s := br.hash.Size() - k := padLen(br.lxz.n) - q := make([]byte, k+s, k+2*s) - if _, err = io.ReadFull(br.lxz.r, q); err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return n, err - } - if !allZeros(q[:k]) { - return n, errors.New("xz: non-zero block padding") - } - checkSum := q[k:] - computedSum := br.hash.Sum(checkSum[s:]) - if !bytes.Equal(checkSum, computedSum) { - return n, errors.New("xz: checksum error for block") - } - return n, io.EOF -} - -func (c *ReaderConfig) newFilterReader(r io.Reader, f []filter) (fr io.Reader, - err error) { - - if err = verifyFilters(f); err != nil { - return nil, err - } - - fr = r - for i := len(f) - 1; i >= 0; i-- { - fr, err = f[i].reader(fr, c) - if err != nil { - return nil, err - } - } - return fr, nil -} diff --git a/reader_test.go b/reader_test.go index 45e725b..ac324a4 100644 --- a/reader_test.go +++ b/reader_test.go @@ -1,7 +1,3 @@ -// Copyright 2014-2019 Ulrich Kunitz. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - package xz import ( diff --git a/writer.go b/writer.go index aec10df..c86f0f6 100644 --- a/writer.go +++ b/writer.go @@ -1,7 +1,3 @@ -// Copyright 2014-2019 Ulrich Kunitz. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - package xz import ( @@ -9,7 +5,9 @@ import ( "hash" "io" + "github.com/ulikunitz/xz/filter" "github.com/ulikunitz/xz/lzma" + "github.com/ulikunitz/xz/xzinternals" ) // WriterConfig describe the parameters for an xz writer. @@ -41,10 +39,10 @@ func (c *WriterConfig) fill() { c.BlockSize = maxInt64 } if c.CheckSum == 0 { - c.CheckSum = CRC64 + c.CheckSum = xzinternals.CRC64 } if c.NoCheckSum { - c.CheckSum = None + c.CheckSum = xzinternals.None } } @@ -67,82 +65,57 @@ func (c *WriterConfig) Verify() error { if c.BlockSize <= 0 { return errors.New("xz: block size out of range") } - if err := verifyFlags(c.CheckSum); err != nil { + if err := xzinternals.VerifyFlags(c.CheckSum); err != nil { return err } return nil } -// filters creates the filter list for the given parameters. 
-func (c *WriterConfig) filters() []filter { - return []filter{&lzmaFilter{int64(c.DictCap)}} -} - -// maxInt64 defines the maximum 64-bit signed integer. -const maxInt64 = 1<<63 - 1 - -// verifyFilters checks the filter list for the length and the right -// sequence of filters. -func verifyFilters(f []filter) error { - if len(f) == 0 { - return errors.New("xz: no filters") - } - if len(f) > 4 { - return errors.New("xz: more than four filters") - } - for _, g := range f[:len(f)-1] { - if g.last() { - return errors.New("xz: last filter is not last") - } +// newBlockWriter creates a new block writer. +func (c *WriterConfig) newBlockWriter(xz io.Writer, hash hash.Hash) (bw *xzinternals.BlockWriter, err error) { + bw = &xzinternals.BlockWriter{ + CXZ: xzinternals.NewCountingWriter(xz), + BlockSize: c.BlockSize, + Filters: c.filters(), + Hash: hash, } - if !f[len(f)-1].last() { - return errors.New("xz: wrong last filter") + + fwc := &filter.WriterConfig{ + Properties: c.Properties, + DictCap: c.DictCap, + BufSize: c.BufSize, + Matcher: c.Matcher, } - return nil -} -// newFilterWriteCloser converts a filter list into a WriteCloser that -// can be used by a blockWriter. -func (c *WriterConfig) newFilterWriteCloser(w io.Writer, f []filter) (fw io.WriteCloser, err error) { - if err = verifyFilters(f); err != nil { + bw.W, err = filter.NewFilterWriteCloser(fwc, &bw.CXZ, bw.Filters) + if err != nil { return nil, err } - fw = nopWriteCloser(w) - for i := len(f) - 1; i >= 0; i-- { - fw, err = f[i].writeCloser(fw, c) - if err != nil { - return nil, err - } + if bw.Hash.Size() != 0 { + bw.MW = io.MultiWriter(bw.W, bw.Hash) + } else { + bw.MW = bw.W } - return fw, nil -} - -// nopWCloser implements a WriteCloser with a Close method not doing -// anything. -type nopWCloser struct { - io.Writer + return bw, nil } -// Close returns nil and doesn't do anything else. -func (c nopWCloser) Close() error { - return nil +// filters creates the filter list for the given parameters. +func (c *WriterConfig) filters() []filter.Filter { + return []filter.Filter{filter.NewLZMAFilter(int64(c.DictCap))} } -// nopWriteCloser converts the Writer into a WriteCloser with a Close -// function that does nothing beside returning nil. -func nopWriteCloser(w io.Writer) io.WriteCloser { - return nopWCloser{w} -} +// maxInt64 defines the maximum 64-bit signed integer. +const maxInt64 = 1<<63 - 1 // Writer compresses data written to it. It is an io.WriteCloser. 
type Writer struct { WriterConfig xz io.Writer - bw *blockWriter + bw *xzinternals.BlockWriter newHash func() hash.Hash - h header - index []record + h xzinternals.Header + index []xzinternals.Record closed bool } @@ -153,7 +126,7 @@ func (w *Writer) newBlockWriter() error { if err != nil { return err } - if err = w.bw.writeHeader(w.xz); err != nil { + if err = w.bw.WriteHeader(w.xz); err != nil { return err } return nil @@ -166,7 +139,7 @@ func (w *Writer) closeBlockWriter() error { if err = w.bw.Close(); err != nil { return err } - w.index = append(w.index, w.bw.record()) + w.index = append(w.index, w.bw.Record()) return nil } @@ -183,10 +156,10 @@ func (c WriterConfig) NewWriter(xz io.Writer) (w *Writer, err error) { w = &Writer{ WriterConfig: c, xz: xz, - h: header{c.CheckSum}, - index: make([]record, 0, 4), + h: xzinternals.Header{c.CheckSum}, + index: make([]xzinternals.Record, 0, 4), } - if w.newHash, err = newHashFunc(c.CheckSum); err != nil { + if w.newHash, err = xzinternals.NewHashFunc(c.CheckSum); err != nil { return nil, err } data, err := w.h.MarshalBinary() @@ -203,12 +176,12 @@ func (c WriterConfig) NewWriter(xz io.Writer) (w *Writer, err error) { // Write compresses the uncompressed data provided. func (w *Writer) Write(p []byte) (n int, err error) { if w.closed { - return 0, errClosed + return 0, xzinternals.ErrClosed } for { k, err := w.bw.Write(p[n:]) n += k - if err != errNoSpace { + if err != xzinternals.ErrNoSpace { return n, err } if err = w.closeBlockWriter(); err != nil { @@ -224,7 +197,7 @@ func (w *Writer) Write(p []byte) (n int, err error) { // doesn't close the underlying writer. func (w *Writer) Close() error { if w.closed { - return errClosed + return xzinternals.ErrClosed } w.closed = true var err error @@ -232,8 +205,8 @@ func (w *Writer) Close() error { return err } - f := footer{flags: w.h.flags} - if f.indexSize, err = writeIndex(w.xz, w.index); err != nil { + f := xzinternals.Footer{Flags: w.h.Flags} + if f.IndexSize, err = xzinternals.WriteIndex(w.xz, w.index); err != nil { return err } data, err := f.MarshalBinary() @@ -245,151 +218,3 @@ func (w *Writer) Close() error { } return nil } - -// countingWriter is a writer that counts all data written to it. -type countingWriter struct { - w io.Writer - n int64 -} - -// Write writes data to the countingWriter. -func (cw *countingWriter) Write(p []byte) (n int, err error) { - n, err = cw.w.Write(p) - cw.n += int64(n) - if err == nil && cw.n < 0 { - return n, errors.New("xz: counter overflow") - } - return -} - -// blockWriter is writes a single block. -type blockWriter struct { - cxz countingWriter - // mw combines io.WriteCloser w and the hash. - mw io.Writer - w io.WriteCloser - n int64 - blockSize int64 - closed bool - headerLen int - - filters []filter - hash hash.Hash -} - -// newBlockWriter creates a new block writer. -func (c *WriterConfig) newBlockWriter(xz io.Writer, hash hash.Hash) (bw *blockWriter, err error) { - bw = &blockWriter{ - cxz: countingWriter{w: xz}, - blockSize: c.BlockSize, - filters: c.filters(), - hash: hash, - } - bw.w, err = c.newFilterWriteCloser(&bw.cxz, bw.filters) - if err != nil { - return nil, err - } - if bw.hash.Size() != 0 { - bw.mw = io.MultiWriter(bw.w, bw.hash) - } else { - bw.mw = bw.w - } - return bw, nil -} - -// writeHeader writes the header. If the function is called after Close -// the commpressedSize and uncompressedSize fields will be filled. 
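The writing side keeps its public API as well; blocks, index, and footer are now produced through the xzinternals types behind WriterConfig.NewWriter. A small sketch, with the output file name and checksum choice as assumptions:

package main

import (
	"os"

	"github.com/ulikunitz/xz"
)

func main() {
	out, err := os.Create("example.xz") // hypothetical output file
	if err != nil {
		panic(err)
	}
	defer out.Close()

	w, err := xz.WriterConfig{CheckSum: xz.SHA256}.NewWriter(out)
	if err != nil {
		panic(err)
	}
	if _, err := w.Write([]byte("The quick brown fox jumps over the lazy dog.\n")); err != nil {
		panic(err)
	}
	// Close finishes the last block and writes the index and the footer.
	if err := w.Close(); err != nil {
		panic(err)
	}
}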
-func (bw *blockWriter) writeHeader(w io.Writer) error { - h := blockHeader{ - compressedSize: -1, - uncompressedSize: -1, - filters: bw.filters, - } - if bw.closed { - h.compressedSize = bw.compressedSize() - h.uncompressedSize = bw.uncompressedSize() - } - data, err := h.MarshalBinary() - if err != nil { - return err - } - if _, err = w.Write(data); err != nil { - return err - } - bw.headerLen = len(data) - return nil -} - -// compressed size returns the amount of data written to the underlying -// stream. -func (bw *blockWriter) compressedSize() int64 { - return bw.cxz.n -} - -// uncompressedSize returns the number of data written to the -// blockWriter -func (bw *blockWriter) uncompressedSize() int64 { - return bw.n -} - -// unpaddedSize returns the sum of the header length, the uncompressed -// size of the block and the hash size. -func (bw *blockWriter) unpaddedSize() int64 { - if bw.headerLen <= 0 { - panic("xz: block header not written") - } - n := int64(bw.headerLen) - n += bw.compressedSize() - n += int64(bw.hash.Size()) - return n -} - -// record returns the record for the current stream. Call Close before -// calling this method. -func (bw *blockWriter) record() record { - return record{bw.unpaddedSize(), bw.uncompressedSize()} -} - -var errClosed = errors.New("xz: writer already closed") - -var errNoSpace = errors.New("xz: no space") - -// Write writes uncompressed data to the block writer. -func (bw *blockWriter) Write(p []byte) (n int, err error) { - if bw.closed { - return 0, errClosed - } - - t := bw.blockSize - bw.n - if int64(len(p)) > t { - err = errNoSpace - p = p[:t] - } - - var werr error - n, werr = bw.mw.Write(p) - bw.n += int64(n) - if werr != nil { - return n, werr - } - return n, err -} - -// Close closes the writer. -func (bw *blockWriter) Close() error { - if bw.closed { - return errClosed - } - bw.closed = true - if err := bw.w.Close(); err != nil { - return err - } - s := bw.hash.Size() - k := padLen(bw.cxz.n) - p := make([]byte, k+s) - bw.hash.Sum(p[k:k]) - if _, err := bw.cxz.w.Write(p); err != nil { - return err - } - return nil -} diff --git a/bits.go b/xzinternals/bits.go similarity index 98% rename from bits.go rename to xzinternals/bits.go index 364213d..2f0e849 100644 --- a/bits.go +++ b/xzinternals/bits.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xz +package xzinternals import ( "errors" diff --git a/bits_test.go b/xzinternals/bits_test.go similarity index 97% rename from bits_test.go rename to xzinternals/bits_test.go index 8530056..e44d700 100644 --- a/bits_test.go +++ b/xzinternals/bits_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xz +package xzinternals import ( "bytes" diff --git a/crc.go b/xzinternals/crc.go similarity index 98% rename from crc.go rename to xzinternals/crc.go index 638774a..9e3a1fb 100644 --- a/crc.go +++ b/xzinternals/crc.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xz +package xzinternals import ( "hash" diff --git a/example.go b/xzinternals/example.go similarity index 100% rename from example.go rename to xzinternals/example.go diff --git a/xzinternals/format.go b/xzinternals/format.go new file mode 100644 index 0000000..d470cdd --- /dev/null +++ b/xzinternals/format.go @@ -0,0 +1,716 @@ +// Copyright 2014-2019 Ulrich Kunitz. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xzinternals + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "hash" + "hash/crc32" + "io" + + "github.com/ulikunitz/xz/filter" + "github.com/ulikunitz/xz/lzma" +) + +// allZeros checks whether a given byte slice has only zeros. +func allZeros(p []byte) bool { + for _, c := range p { + if c != 0 { + return false + } + } + return true +} + +// padLen returns the length of the padding required for the given +// argument. +func padLen(n int64) int { + k := int(n % 4) + if k > 0 { + k = 4 - k + } + return k +} + +/*** Header ***/ + +// headerMagic stores the magic bytes for the header +var headerMagic = []byte{0xfd, '7', 'z', 'X', 'Z', 0x00} + +// HeaderLen provides the length of the xz file header. +const HeaderLen = 12 + +// Constants for the checksum methods supported by xz. +const ( + None byte = 0x0 + CRC32 = 0x1 + CRC64 = 0x4 + SHA256 = 0xa +) + +// errInvalidFlags indicates that flags are invalid. +var errInvalidFlags = errors.New("xz: invalid flags") + +// VerifyFlags returns the error errInvalidFlags if the value is +// invalid. +func VerifyFlags(flags byte) error { + switch flags { + case None, CRC32, CRC64, SHA256: + return nil + default: + return errInvalidFlags + } +} + +// flagstrings maps flag values to strings. +var flagstrings = map[byte]string{ + None: "None", + CRC32: "CRC-32", + CRC64: "CRC-64", + SHA256: "SHA-256", +} + +// flagString returns the string representation for the given flags. +func flagString(flags byte) string { + s, ok := flagstrings[flags] + if !ok { + return "invalid" + } + return s +} + +// NewHashFunc returns a function that creates hash instances for the +// hash method encoded in flags. +func NewHashFunc(flags byte) (newHash func() hash.Hash, err error) { + switch flags { + case None: + newHash = newNoneHash + case CRC32: + newHash = newCRC32 + case CRC64: + newHash = newCRC64 + case SHA256: + newHash = sha256.New + default: + err = errInvalidFlags + } + return +} + +// Header provides the actual content of the xz file Header: the flags. +type Header struct { + Flags byte +} + +// Errors returned by readHeader. +var errHeaderMagic = errors.New("xz: invalid header magic bytes") + +// String returns a string representation of the flags. +func (h Header) String() string { + return flagString(h.Flags) +} + +// UnmarshalBinary reads header from the provided data slice. +func (h *Header) UnmarshalBinary(data []byte) error { + // header length + if len(data) != HeaderLen { + return errors.New("xz: wrong file header length") + } + + // magic header + if !bytes.Equal(headerMagic, data[:6]) { + return errHeaderMagic + } + + // checksum + crc := crc32.NewIEEE() + crc.Write(data[6:8]) + if uint32LE(data[8:]) != crc.Sum32() { + return errors.New("xz: invalid checksum for file header") + } + + // stream flags + if data[6] != 0 { + return errInvalidFlags + } + flags := data[7] + if err := VerifyFlags(flags); err != nil { + return err + } + + h.Flags = flags + return nil +} + +// MarshalBinary generates the xz file header. +func (h *Header) MarshalBinary() (data []byte, err error) { + if err = VerifyFlags(h.Flags); err != nil { + return nil, err + } + + data = make([]byte, 12) + copy(data, headerMagic) + data[7] = h.Flags + + crc := crc32.NewIEEE() + crc.Write(data[6:8]) + putUint32LE(data[8:], crc.Sum32()) + + return data, nil +} + +/*** Footer ***/ + +// footerLen defines the length of the footer. 
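The stream-header and checksum-flag helpers are exported from xzinternals (the package does not live under internal/, so it stays importable). A hedged round-trip sketch mirroring TestHeader:

package main

import (
	"fmt"

	"github.com/ulikunitz/xz/xzinternals"
)

func main() {
	h := xzinternals.Header{Flags: xzinternals.CRC32}
	data, err := h.MarshalBinary() // 12 bytes: magic, flags, CRC-32
	if err != nil {
		panic(err)
	}

	var g xzinternals.Header
	if err := g.UnmarshalBinary(data); err != nil {
		panic(err)
	}
	fmt.Printf("flags round-tripped: %v (%s)\n", g.Flags == h.Flags, g)

	// NewHashFunc maps the checksum flag to a hash constructor.
	newHash, err := xzinternals.NewHashFunc(g.Flags)
	if err != nil {
		panic(err)
	}
	fmt.Println("checksum size:", newHash().Size()) // 4 for CRC-32
}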
+const footerLen = 12 + +// footerMagic contains the footer magic bytes. +var footerMagic = []byte{'Y', 'Z'} + +// Footer represents the content of the xz file Footer. +type Footer struct { + IndexSize int64 + Flags byte +} + +// String prints a string representation of the footer structure. +func (f Footer) String() string { + return fmt.Sprintf("%s index size %d", flagString(f.Flags), f.IndexSize) +} + +// Minimum and maximum for the size of the index (backward size). +const ( + minIndexSize = 4 + maxIndexSize = 1 << 32 * 4 +) + +// MarshalBinary converts footer values into an xz file footer. Note +// that the footer value is checked for correctness. +func (f *Footer) MarshalBinary() (data []byte, err error) { + if err = VerifyFlags(f.Flags); err != nil { + return nil, err + } + if !(minIndexSize <= f.IndexSize && f.IndexSize <= maxIndexSize) { + return nil, errors.New("xz: index size out of range") + } + if f.IndexSize%4 != 0 { + return nil, errors.New( + "xz: index size not aligned to four bytes") + } + + data = make([]byte, footerLen) + + // backward size (index size) + s := f.IndexSize/4 - 1 + putUint32LE(data[4:], uint32(s)) + // flags + data[9] = f.Flags + // footer magic + copy(data[10:], footerMagic) + + // CRC-32 + crc := crc32.NewIEEE() + crc.Write(data[4:10]) + putUint32LE(data, crc.Sum32()) + + return data, nil +} + +// UnmarshalBinary sets the footer value by unmarshalling an xz file +// footer. +func (f *Footer) UnmarshalBinary(data []byte) error { + if len(data) != footerLen { + return errors.New("xz: wrong footer length") + } + + // magic bytes + if !bytes.Equal(data[10:], footerMagic) { + return errors.New("xz: footer magic invalid") + } + + // CRC-32 + crc := crc32.NewIEEE() + crc.Write(data[4:10]) + if uint32LE(data) != crc.Sum32() { + return errors.New("xz: footer checksum error") + } + + var g Footer + // backward size (index size) + g.IndexSize = (int64(uint32LE(data[4:])) + 1) * 4 + + // flags + if data[8] != 0 { + return errInvalidFlags + } + g.Flags = data[9] + if err := VerifyFlags(g.Flags); err != nil { + return err + } + + *f = g + return nil +} + +/*** Block Header ***/ + +// BlockHeader represents the content of an xz block header. +type BlockHeader struct { + CompressedSize int64 + UncompressedSize int64 + Filters []filter.Filter +} + +// String converts the block header into a string. +func (h BlockHeader) String() string { + var buf bytes.Buffer + first := true + if h.CompressedSize >= 0 { + fmt.Fprintf(&buf, "compressed size %d", h.CompressedSize) + first = false + } + if h.UncompressedSize >= 0 { + if !first { + buf.WriteString(" ") + } + fmt.Fprintf(&buf, "uncompressed size %d", h.UncompressedSize) + first = false + } + for _, f := range h.Filters { + if !first { + buf.WriteString(" ") + } + fmt.Fprintf(&buf, "filter %s", f) + first = false + } + return buf.String() +} + +// Masks for the block flags. +const ( + filterCountMask = 0x03 + compressedSizePresent = 0x40 + uncompressedSizePresent = 0x80 + reservedBlockFlags = 0x3C +) + +// errIndexIndicator signals that an index indicator (0x00) has been found +// instead of an expected block header indicator. +var errIndexIndicator = errors.New("xz: found index indicator") + +// readBlockHeader reads the block header. 
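The footer stores the index size as a backward size of IndexSize/4 - 1 in a little-endian uint32, so a 64-byte index is encoded as 15. A short round-trip sketch along the lines of TestFooter:

package main

import (
	"fmt"

	"github.com/ulikunitz/xz/xzinternals"
)

func main() {
	f := xzinternals.Footer{IndexSize: 64, Flags: xzinternals.CRC32}
	data, err := f.MarshalBinary()
	if err != nil {
		panic(err)
	}
	fmt.Println("backward size field:", data[4]) // 15

	var g xzinternals.Footer
	if err := g.UnmarshalBinary(data); err != nil {
		panic(err)
	}
	fmt.Println(g.IndexSize, g.Flags == xzinternals.CRC32) // 64 true
}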
+func readBlockHeader(r io.Reader) (h *BlockHeader, n int, err error) { + var buf bytes.Buffer + buf.Grow(20) + + // block header size + z, err := io.CopyN(&buf, r, 1) + n = int(z) + if err != nil { + return nil, n, err + } + s := buf.Bytes()[0] + if s == 0 { + return nil, n, errIndexIndicator + } + + // read complete header + headerLen := (int(s) + 1) * 4 + buf.Grow(headerLen - 1) + z, err = io.CopyN(&buf, r, int64(headerLen-1)) + n += int(z) + if err != nil { + return nil, n, err + } + + // unmarshal block header + h = new(BlockHeader) + if err = h.UnmarshalBinary(buf.Bytes()); err != nil { + return nil, n, err + } + + return h, n, nil +} + +// readSizeInBlockHeader reads the uncompressed or compressed size +// fields in the block header. The present value informs the function +// whether the respective field is actually present in the header. +func readSizeInBlockHeader(r io.ByteReader, present bool) (n int64, err error) { + if !present { + return -1, nil + } + x, _, err := readUvarint(r) + if err != nil { + return 0, err + } + if x >= 1<<63 { + return 0, errors.New("xz: size overflow in block header") + } + return int64(x), nil +} + +var ErrPadding = errors.New("xz: padding (4 zero bytes) encountered") + +// UnmarshalBinary unmarshals the block header. +func (h *BlockHeader) UnmarshalBinary(data []byte) error { + // Check header length + s := data[0] + if data[0] == 0 { + return errIndexIndicator + } + headerLen := (int(s) + 1) * 4 + if len(data) != headerLen { + return fmt.Errorf("xz: data length %d; want %d", len(data), + headerLen) + } + n := headerLen - 4 + + // Check CRC-32 + crc := crc32.NewIEEE() + crc.Write(data[:n]) + if crc.Sum32() != uint32LE(data[n:]) { + return errors.New("xz: checksum error for block header") + } + + // Block header flags + flags := data[1] + if flags&reservedBlockFlags != 0 { + return errors.New("xz: reserved block header flags set") + } + + r := bytes.NewReader(data[2:n]) + + // Compressed size + var err error + h.CompressedSize, err = readSizeInBlockHeader( + r, flags&compressedSizePresent != 0) + if err != nil { + return err + } + + // Uncompressed size + h.UncompressedSize, err = readSizeInBlockHeader( + r, flags&uncompressedSizePresent != 0) + if err != nil { + return err + } + + h.Filters, err = readFilters(r, int(flags&filterCountMask)+1) + if err != nil { + return err + } + + // Check padding + // Since headerLen is a multiple of 4 we don't need to check + // alignment. + k := r.Len() + // The standard spec says that the padding should have not more + // than 3 bytes. However we found paddings of 4 or 5 in the + // wild. See https://github.com/ulikunitz/xz/pull/11 and + // https://github.com/ulikunitz/xz/issues/15 + // + // The only reasonable approach seems to be to ignore the + // padding size. We still check that all padding bytes are zero. + if !allZeros(data[n-k : n]) { + return ErrPadding + } + return nil +} + +// MarshalBinary marshals the binary header. 
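The first byte of a block header encodes the total header length as (size byte + 1) * 4, and a zero byte in that position marks the index instead. A hedged sketch of marshalling and re-reading a BlockHeader, following TestBlockHeader; the sizes are arbitrary:

package main

import (
	"fmt"

	"github.com/ulikunitz/xz/filter"
	"github.com/ulikunitz/xz/xzinternals"
)

func main() {
	h := xzinternals.BlockHeader{
		CompressedSize:   1234,
		UncompressedSize: -1, // -1 omits the field from the header
		Filters:          []filter.Filter{filter.NewLZMAFilter(4096)},
	}
	data, err := h.MarshalBinary()
	if err != nil {
		panic(err)
	}
	// The size byte and the actual length agree: len(data) == (data[0]+1)*4.
	fmt.Println(len(data), (int(data[0])+1)*4)

	var g xzinternals.BlockHeader
	if err := g.UnmarshalBinary(data); err != nil {
		panic(err)
	}
	fmt.Println(g.CompressedSize, g.UncompressedSize) // 1234 -1
}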
+func (h *BlockHeader) MarshalBinary() (data []byte, err error) { + if !(minFilters <= len(h.Filters) && len(h.Filters) <= maxFilters) { + return nil, errors.New("xz: filter count wrong") + } + for i, f := range h.Filters { + if i < len(h.Filters)-1 { + if f.ID() == filter.LZMAFilterID { + return nil, errors.New( + "xz: LZMA2 filter is not the last") + } + } else { + // last filter + if f.ID() != filter.LZMAFilterID { + return nil, errors.New("xz: " + + "last filter must be the LZMA2 filter") + } + } + } + + var buf bytes.Buffer + // header size must set at the end + buf.WriteByte(0) + + // flags + flags := byte(len(h.Filters) - 1) + if h.CompressedSize >= 0 { + flags |= compressedSizePresent + } + if h.UncompressedSize >= 0 { + flags |= uncompressedSizePresent + } + buf.WriteByte(flags) + + p := make([]byte, 10) + if h.CompressedSize >= 0 { + k := putUvarint(p, uint64(h.CompressedSize)) + buf.Write(p[:k]) + } + if h.UncompressedSize >= 0 { + k := putUvarint(p, uint64(h.UncompressedSize)) + buf.Write(p[:k]) + } + + for _, f := range h.Filters { + fp, err := f.MarshalBinary() + if err != nil { + return nil, err + } + buf.Write(fp) + } + + // padding + for i := padLen(int64(buf.Len())); i > 0; i-- { + buf.WriteByte(0) + } + + // crc place holder + buf.Write(p[:4]) + + data = buf.Bytes() + if len(data)%4 != 0 { + panic("data length not aligned") + } + s := len(data)/4 - 1 + if !(1 < s && s <= 255) { + panic("wrong block header size") + } + data[0] = byte(s) + + crc := crc32.NewIEEE() + crc.Write(data[:len(data)-4]) + putUint32LE(data[len(data)-4:], crc.Sum32()) + + return data, nil +} + +// Constants used for marshalling and unmarshalling filters in the xz +// block header. +const ( + minFilters = 1 + maxFilters = 4 + minReservedID = 1 << 62 +) + +// readFilter reads a block filter from the block header. At this point +// in time only the LZMA2 filter is supported. +func readFilter(r io.Reader) (f filter.Filter, err error) { + br := lzma.ByteReader(r) + + // index + id, _, err := readUvarint(br) + if err != nil { + return nil, err + } + + var data []byte + switch id { + case filter.LZMAFilterID: + data = make([]byte, filter.LZMAFilterLen) + data[0] = filter.LZMAFilterID + if _, err = io.ReadFull(r, data[1:]); err != nil { + return nil, err + } + f = new(filter.LZMAFilter) + default: + if id >= minReservedID { + return nil, errors.New( + "xz: reserved filter id in block stream header") + } + return nil, errors.New("xz: invalid filter id") + } + if err = f.UnmarshalBinary(data); err != nil { + return nil, err + } + return f, err +} + +// readFilters reads count filters. At this point in time only the count +// 1 is supported. +func readFilters(r io.Reader, count int) (filters []filter.Filter, err error) { + if count != 1 { + return nil, errors.New("xz: unsupported filter count") + } + f, err := readFilter(r) + if err != nil { + return nil, err + } + return []filter.Filter{f}, err +} + +// writeFilters writes the filters. +func writeFilters(w io.Writer, filters []filter.Filter) (n int, err error) { + for _, f := range filters { + p, err := f.MarshalBinary() + if err != nil { + return n, err + } + k, err := w.Write(p) + n += k + if err != nil { + return n, err + } + } + return n, nil +} + +/*** Index ***/ + +// Record describes a block in the xz file index. +type Record struct { + unpaddedSize int64 + uncompressedSize int64 +} + +// readRecord reads an index record. 
+func readRecord(r io.ByteReader) (rec Record, n int, err error) { + u, k, err := readUvarint(r) + n += k + if err != nil { + return rec, n, err + } + rec.unpaddedSize = int64(u) + if rec.unpaddedSize < 0 { + return rec, n, errors.New("xz: unpadded size negative") + } + + u, k, err = readUvarint(r) + n += k + if err != nil { + return rec, n, err + } + rec.uncompressedSize = int64(u) + if rec.uncompressedSize < 0 { + return rec, n, errors.New("xz: uncompressed size negative") + } + + return rec, n, nil +} + +// MarshalBinary converts an index record in its binary encoding. +func (rec *Record) MarshalBinary() (data []byte, err error) { + // maximum length of a uvarint is 10 + p := make([]byte, 20) + n := putUvarint(p, uint64(rec.unpaddedSize)) + n += putUvarint(p[n:], uint64(rec.uncompressedSize)) + return p[:n], nil +} + +// WriteIndex writes the index, a sequence of records. +func WriteIndex(w io.Writer, index []Record) (n int64, err error) { + crc := crc32.NewIEEE() + mw := io.MultiWriter(w, crc) + + // index indicator + k, err := mw.Write([]byte{0}) + n += int64(k) + if err != nil { + return n, err + } + + // number of records + p := make([]byte, 10) + k = putUvarint(p, uint64(len(index))) + k, err = mw.Write(p[:k]) + n += int64(k) + if err != nil { + return n, err + } + + // list of records + for _, rec := range index { + p, err := rec.MarshalBinary() + if err != nil { + return n, err + } + k, err = mw.Write(p) + n += int64(k) + if err != nil { + return n, err + } + } + + // index padding + k, err = mw.Write(make([]byte, padLen(int64(n)))) + n += int64(k) + if err != nil { + return n, err + } + + // crc32 checksum + putUint32LE(p, crc.Sum32()) + k, err = w.Write(p[:4]) + n += int64(k) + + return n, err +} + +// readIndexBody reads the index from the reader. It assumes that the +// index indicator has already been read. +func readIndexBody(r io.Reader) (records []Record, n int64, err error) { + crc := crc32.NewIEEE() + // index indicator + crc.Write([]byte{0}) + + br := lzma.ByteReader(io.TeeReader(r, crc)) + + // number of records + u, k, err := readUvarint(br) + n += int64(k) + if err != nil { + return nil, n, err + } + recLen := int(u) + if recLen < 0 || uint64(recLen) != u { + return nil, n, errors.New("xz: record number overflow") + } + + // list of records + records = make([]Record, recLen) + for i := range records { + records[i], k, err = readRecord(br) + n += int64(k) + if err != nil { + return nil, n, err + } + } + + p := make([]byte, padLen(int64(n+1)), 4) + k, err = io.ReadFull(br.(io.Reader), p) + n += int64(k) + if err != nil { + return nil, n, err + } + if !allZeros(p) { + return nil, n, errors.New("xz: non-zero byte in index padding") + } + + // crc32 + s := crc.Sum32() + p = p[:4] + k, err = io.ReadFull(br.(io.Reader), p) + n += int64(k) + if err != nil { + return records, n, err + } + if uint32LE(p) != s { + return nil, n, errors.New("xz: wrong checksum for index") + } + + return records, n, nil +} diff --git a/format_test.go b/xzinternals/format_test.go similarity index 75% rename from format_test.go rename to xzinternals/format_test.go index 0b875d3..d9b0e66 100644 --- a/format_test.go +++ b/xzinternals/format_test.go @@ -2,20 +2,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xz +package xzinternals import ( "bytes" "testing" + + "github.com/ulikunitz/xz/filter" ) func TestHeader(t *testing.T) { - h := header{flags: CRC32} + h := Header{Flags: CRC32} data, err := h.MarshalBinary() if err != nil { t.Fatalf("MarshalBinary error %s", err) } - var g header + var g Header if err = g.UnmarshalBinary(data); err != nil { t.Fatalf("UnmarshalBinary error %s", err) } @@ -25,12 +27,12 @@ func TestHeader(t *testing.T) { } func TestFooter(t *testing.T) { - f := footer{indexSize: 64, flags: CRC32} + f := Footer{IndexSize: 64, Flags: CRC32} data, err := f.MarshalBinary() if err != nil { t.Fatalf("MarshalBinary error %s", err) } - var g footer + var g Footer if err = g.UnmarshalBinary(data); err != nil { t.Fatalf("UnmarshalBinary error %s", err) } @@ -40,7 +42,7 @@ func TestFooter(t *testing.T) { } func TestRecord(t *testing.T) { - r := record{1234567, 10000} + r := Record{1234567, 10000} p, err := r.MarshalBinary() if err != nil { t.Fatalf("MarshalBinary error %s", err) @@ -65,10 +67,10 @@ func TestRecord(t *testing.T) { } func TestIndex(t *testing.T) { - records := []record{{1234, 1}, {2345, 2}} + records := []Record{{1234, 1}, {2345, 2}} var buf bytes.Buffer - n, err := writeIndex(&buf, records) + n, err := WriteIndex(&buf, records) if err != nil { t.Fatalf("writeIndex error %s", err) } @@ -103,10 +105,10 @@ func TestIndex(t *testing.T) { } func TestBlockHeader(t *testing.T) { - h := blockHeader{ - compressedSize: 1234, - uncompressedSize: -1, - filters: []filter{&lzmaFilter{4096}}, + h := BlockHeader{ + CompressedSize: 1234, + UncompressedSize: -1, + Filters: []filter.Filter{filter.NewLZMAFilter(4096)}, } data, err := h.MarshalBinary() if err != nil { @@ -122,21 +124,21 @@ func TestBlockHeader(t *testing.T) { t.Fatalf("readBlockHeader returns %d bytes; want %d", n, len(data)) } - if g.compressedSize != h.compressedSize { + if g.CompressedSize != h.CompressedSize { t.Errorf("got compressedSize %d; want %d", - g.compressedSize, h.compressedSize) + g.CompressedSize, h.CompressedSize) } - if g.uncompressedSize != h.uncompressedSize { + if g.UncompressedSize != h.UncompressedSize { t.Errorf("got uncompressedSize %d; want %d", - g.uncompressedSize, h.uncompressedSize) + g.UncompressedSize, h.UncompressedSize) } - if len(g.filters) != len(h.filters) { + if len(g.Filters) != len(h.Filters) { t.Errorf("got len(filters) %d; want %d", - len(g.filters), len(h.filters)) + len(g.Filters), len(h.Filters)) } - glf := g.filters[0].(*lzmaFilter) - hlf := h.filters[0].(*lzmaFilter) - if glf.dictCap != hlf.dictCap { - t.Errorf("got dictCap %d; want %d", glf.dictCap, hlf.dictCap) + glf := g.Filters[0].(*filter.LZMAFilter) + hlf := h.Filters[0].(*filter.LZMAFilter) + if glf.GetDictCap() != hlf.GetDictCap() { + t.Errorf("got dictCap %d; want %d", glf.GetDictCap(), hlf.GetDictCap()) } } diff --git a/none-check.go b/xzinternals/none-check.go similarity index 96% rename from none-check.go rename to xzinternals/none-check.go index e12d8e4..f17df4b 100644 --- a/none-check.go +++ b/xzinternals/none-check.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xz +package xzinternals import "hash" diff --git a/none-check_test.go b/xzinternals/none-check_test.go similarity index 94% rename from none-check_test.go rename to xzinternals/none-check_test.go index 6761aa4..b2b4022 100644 --- a/none-check_test.go +++ b/xzinternals/none-check_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xz +package xzinternals import ( "bytes" diff --git a/xzinternals/reader.go b/xzinternals/reader.go new file mode 100644 index 0000000..4ccff44 --- /dev/null +++ b/xzinternals/reader.go @@ -0,0 +1,256 @@ +// Copyright 2014-2019 Ulrich Kunitz. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package xzinternals + +import ( + "bytes" + "errors" + "fmt" + "hash" + "io" + + "github.com/ulikunitz/xz/filter" + "github.com/ulikunitz/xz/internal/xlog" +) + +// StreamReader decodes a single xz stream +type StreamReader struct { + // ReaderConfig + dictCap int + + xz io.Reader + br *BlockReader + newHash func() hash.Hash + h Header + index []Record +} + +// errIndex indicates an error with the xz file index. +var errIndex = errors.New("xz: error in xz file index") + +// readTail reads the index body and the xz footer. +func (r *StreamReader) readTail() error { + index, n, err := readIndexBody(r.xz) + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return err + } + if len(index) != len(r.index) { + return fmt.Errorf("xz: index length is %d; want %d", + len(index), len(r.index)) + } + for i, rec := range r.index { + if rec != index[i] { + return fmt.Errorf("xz: record %d is %v; want %v", + i, rec, index[i]) + } + } + + p := make([]byte, footerLen) + if _, err = io.ReadFull(r.xz, p); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return err + } + var f Footer + if err = f.UnmarshalBinary(p); err != nil { + return err + } + xlog.Debugf("xz footer %s", f) + if f.Flags != r.h.Flags { + return errors.New("xz: footer flags incorrect") + } + if f.IndexSize != int64(n)+1 { + return errors.New("xz: index size in footer wrong") + } + return nil +} + +// Read reads actual data from the xz stream. +func (r *StreamReader) Read(p []byte) (n int, err error) { + for n < len(p) { + if r.br == nil { + bh, hlen, err := readBlockHeader(r.xz) + if err != nil { + if err == errIndexIndicator { + if err = r.readTail(); err != nil { + return n, err + } + return n, io.EOF + } + return n, err + } + xlog.Debugf("block %v", *bh) + r.br, err = NewBlockReader(r.xz, bh, + hlen, r.newHash(), r.dictCap) + if err != nil { + return n, err + } + } + k, err := r.br.Read(p[n:]) + n += k + if err != nil { + if err == io.EOF { + r.index = append(r.index, r.br.record()) + r.br = nil + } else { + return n, err + } + } + } + return n, nil +} + +// countingReader is a reader that counts the bytes read. +type countingReader struct { + r io.Reader + n int64 +} + +// Read reads data from the wrapped reader and adds it to the n field. +func (lr *countingReader) Read(p []byte) (n int, err error) { + n, err = lr.r.Read(p) + lr.n += int64(n) + return n, err +} + +// blockReader supports the reading of a block. +type BlockReader struct { + lxz countingReader + header *BlockHeader + headerLen int + n int64 + hash hash.Hash + r io.Reader + err error +} + +// NewBlockReader creates a new block reader. 
+func NewBlockReader(xz io.Reader, h *BlockHeader, + hlen int, hash hash.Hash, dictCap int) (br *BlockReader, err error) { + + br = &BlockReader{ + lxz: countingReader{r: xz}, + header: h, + headerLen: hlen, + hash: hash, + } + + config := filter.ReaderConfig{ + DictCap: dictCap, + } + + fr, err := filter.NewFilterReader(&config, &br.lxz, h.Filters) + if err != nil { + return nil, err + } + if br.hash.Size() != 0 { + br.r = io.TeeReader(fr, br.hash) + } else { + br.r = fr + } + + return br, nil +} + +// uncompressedSize returns the uncompressed size of the block. +func (br *BlockReader) uncompressedSize() int64 { + return br.n +} + +// compressedSize returns the compressed size of the block. +func (br *BlockReader) compressedSize() int64 { + return br.lxz.n +} + +// unpaddedSize computes the unpadded size for the block. +func (br *BlockReader) unpaddedSize() int64 { + n := int64(br.headerLen) + n += br.compressedSize() + n += int64(br.hash.Size()) + return n +} + +// record returns the index record for the current block. +func (br *BlockReader) record() Record { + return Record{br.unpaddedSize(), br.uncompressedSize()} +} + +// errBlockSize indicates that the size of the block in the block header +// is wrong. +var errBlockSize = errors.New("xz: wrong uncompressed size for block") + +// Read reads data from the block. +func (br *BlockReader) Read(p []byte) (n int, err error) { + n, err = br.r.Read(p) + br.n += int64(n) + + u := br.header.UncompressedSize + if u >= 0 && br.uncompressedSize() > u { + return n, errors.New("xz: wrong uncompressed size for block") + } + c := br.header.CompressedSize + if c >= 0 && br.compressedSize() > c { + return n, errors.New("xz: wrong compressed size for block") + } + if err != io.EOF { + return n, err + } + if br.uncompressedSize() < u || br.compressedSize() < c { + return n, io.ErrUnexpectedEOF + } + + s := br.hash.Size() + k := padLen(br.lxz.n) + q := make([]byte, k+s, k+2*s) + if _, err = io.ReadFull(br.lxz.r, q); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return n, err + } + if !allZeros(q[:k]) { + return n, errors.New("xz: non-zero block padding") + } + checkSum := q[k:] + computedSum := br.hash.Sum(checkSum[s:]) + if !bytes.Equal(checkSum, computedSum) { + return n, errors.New("xz: checksum error for block") + } + return n, io.EOF +} + +// NewStreamReader creates a new xz stream reader using the given configuration +// parameters. NewReader reads and checks the header of the xz stream. +func NewStreamReader(xz io.Reader, dictCap int) (r *StreamReader, err error) { + data := make([]byte, HeaderLen) + if _, err := io.ReadFull(xz, data[:4]); err != nil { + return nil, err + } + if bytes.Equal(data[:4], []byte{0, 0, 0, 0}) { + return nil, ErrPadding + } + if _, err = io.ReadFull(xz, data[4:]); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + r = &StreamReader{ + dictCap: dictCap, + xz: xz, + index: make([]Record, 0, 4), + } + if err = r.h.UnmarshalBinary(data); err != nil { + return nil, err + } + xlog.Debugf("xz header %s", r.h) + if r.newHash, err = NewHashFunc(r.h.Flags); err != nil { + return nil, err + } + return r, nil +} diff --git a/xzinternals/writer.go b/xzinternals/writer.go new file mode 100644 index 0000000..2157729 --- /dev/null +++ b/xzinternals/writer.go @@ -0,0 +1,145 @@ +// Copyright 2014-2019 Ulrich Kunitz. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+package xzinternals
+
+import (
+	"errors"
+	"hash"
+	"io"
+
+	"github.com/ulikunitz/xz/filter"
+)
+
+// countingWriter is a writer that counts all data written to it.
+type countingWriter struct {
+	w io.Writer
+	n int64
+}
+
+func NewCountingWriter(wr io.Writer) countingWriter {
+	return countingWriter{w: wr}
+}
+
+// Write writes data to the countingWriter.
+func (cw *countingWriter) Write(p []byte) (n int, err error) {
+	n, err = cw.w.Write(p)
+	cw.n += int64(n)
+	if err == nil && cw.n < 0 {
+		return n, errors.New("xz: counter overflow")
+	}
+	return
+}
+
+// BlockWriter writes a single block.
+type BlockWriter struct {
+	CXZ countingWriter
+	// MW combines io.WriteCloser W and the hash.
+	MW        io.Writer
+	W         io.WriteCloser
+	n         int64
+	BlockSize int64
+	closed    bool
+	headerLen int
+
+	Filters []filter.Filter
+	Hash    hash.Hash
+}
+
+// WriteHeader writes the block header. If the function is called after
+// Close, the CompressedSize and UncompressedSize fields of the header
+// will be filled.
+func (bw *BlockWriter) WriteHeader(w io.Writer) error {
+	h := BlockHeader{
+		CompressedSize:   -1,
+		UncompressedSize: -1,
+		Filters:          bw.Filters,
+	}
+	if bw.closed {
+		h.CompressedSize = bw.compressedSize()
+		h.UncompressedSize = bw.uncompressedSize()
+	}
+	data, err := h.MarshalBinary()
+	if err != nil {
+		return err
+	}
+	if _, err = w.Write(data); err != nil {
+		return err
+	}
+	bw.headerLen = len(data)
+	return nil
+}
+
+// compressedSize returns the amount of compressed data written to the
+// underlying stream.
+func (bw *BlockWriter) compressedSize() int64 {
+	return bw.CXZ.n
+}
+
+// uncompressedSize returns the number of uncompressed bytes written to
+// the BlockWriter.
+func (bw *BlockWriter) uncompressedSize() int64 {
+	return bw.n
+}
+
+// unpaddedSize returns the sum of the header length, the compressed
+// size of the block and the hash size.
+func (bw *BlockWriter) unpaddedSize() int64 {
+	if bw.headerLen <= 0 {
+		panic("xz: block header not written")
+	}
+	n := int64(bw.headerLen)
+	n += bw.compressedSize()
+	n += int64(bw.Hash.Size())
+	return n
+}
+
+// Record returns the Record for the current block. Call Close before
+// calling this method.
+func (bw *BlockWriter) Record() Record {
+	return Record{bw.unpaddedSize(), bw.uncompressedSize()}
+}
+
+var ErrClosed = errors.New("xz: writer already closed")
+
+var ErrNoSpace = errors.New("xz: no space")
+
+// Write writes uncompressed data to the block writer.
+func (bw *BlockWriter) Write(p []byte) (n int, err error) {
+	if bw.closed {
+		return 0, ErrClosed
+	}
+
+	t := bw.BlockSize - bw.n
+	if int64(len(p)) > t {
+		err = ErrNoSpace
+		p = p[:t]
+	}
+
+	var werr error
+	n, werr = bw.MW.Write(p)
+	bw.n += int64(n)
+	if werr != nil {
+		return n, werr
+	}
+	return n, err
+}
+
+// Close closes the writer.
+func (bw *BlockWriter) Close() error {
+	if bw.closed {
+		return ErrClosed
+	}
+	bw.closed = true
+	if err := bw.W.Close(); err != nil {
+		return err
+	}
+	s := bw.Hash.Size()
+	k := padLen(bw.CXZ.n)
+	p := make([]byte, k+s)
+	bw.Hash.Sum(p[k:k])
+	if _, err := bw.CXZ.w.Write(p); err != nil {
+		return err
+	}
+	return nil
+}
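
Reviewer sketches. The snippets below are editorial illustrations of the API surface this patch exposes; they are written in Go against the import paths shown in the diff (github.com/ulikunitz/xz/filter and github.com/ulikunitz/xz/xzinternals), names such as example.xz and TestIndexRoundTrip are hypothetical, and nothing here is part of the patch itself.

First, block header marshalling as implemented by BlockHeader.MarshalBinary: the first byte stores the header size as size/4 - 1, the compressed and uncompressed sizes are written only when non-negative, and a CRC32 over everything but the last four bytes closes the header. A minimal sketch, assuming the packages build as shown above:

package main

import (
	"fmt"

	"github.com/ulikunitz/xz/filter"
	"github.com/ulikunitz/xz/xzinternals"
)

func main() {
	h := xzinternals.BlockHeader{
		CompressedSize:   -1,   // unknown: the size-present flag stays clear
		UncompressedSize: 4096, // known: written as a uvarint
		Filters:          []filter.Filter{filter.NewLZMAFilter(8 << 20)},
	}
	data, err := h.MarshalBinary()
	if err != nil {
		panic(err)
	}
	// The real header length can be recovered from the size byte:
	// (data[0] + 1) * 4 must equal len(data), CRC32 included.
	fmt.Println(len(data), (int(data[0])+1)*4)
}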
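The index code is symmetric, but not obviously so: WriteIndex emits the index indicator byte itself, while readIndexBody assumes the indicator has already been consumed (the block-header reader discovers it while looking for the next block). Because Record's fields and readIndexBody are unexported, a round-trip check has to live inside package xzinternals, in the style of the existing tests; the test below is a sketch, not part of the patch:

package xzinternals

import (
	"bytes"
	"testing"
)

func TestIndexRoundTrip(t *testing.T) {
	records := []Record{{unpaddedSize: 4096, uncompressedSize: 8192}}

	var buf bytes.Buffer
	if _, err := WriteIndex(&buf, records); err != nil {
		t.Fatalf("WriteIndex error %s", err)
	}

	// Skip the index indicator (0x00) written by WriteIndex;
	// readIndexBody expects it to be gone already.
	buf.Next(1)

	got, _, err := readIndexBody(&buf)
	if err != nil {
		t.Fatalf("readIndexBody error %s", err)
	}
	if len(got) != 1 || got[0] != records[0] {
		t.Fatalf("got %v; want %v", got, records)
	}
}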
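On the reading side, NewStreamReader wires header parsing, per-block BlockReaders and the final index/footer verification behind a plain io.Reader, so a single xz stream can be decoded with io.Copy. A sketch under the assumption that example.xz holds exactly one stream (multi-stream files and stream padding remain the job of the top-level xz.Reader):

package main

import (
	"io"
	"os"

	"github.com/ulikunitz/xz/xzinternals"
)

func main() {
	f, err := os.Open("example.xz") // hypothetical input file
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// 8 MiB dictionary capacity; the block headers may demand more.
	r, err := xzinternals.NewStreamReader(f, 8<<20)
	if err != nil {
		panic(err)
	}
	// Read returns io.EOF only after the stream's index and footer
	// have been read back and verified against the decoded blocks.
	if _, err := io.Copy(os.Stdout, r); err != nil {
		panic(err)
	}
}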
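On the writing side, BlockWriter leaves the wiring of the filter chain, the counting writer and the check hash to its caller (the top-level xz writer in this patch). The sketch below shows one plausible wiring by hand: it writes a header followed by a compressed, padded and check-summed block body. It illustrates the moving parts rather than the writer's documented contract, and SHA-256 merely stands in for whichever check the stream flags would select:

package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"io"

	"github.com/ulikunitz/xz/filter"
	"github.com/ulikunitz/xz/xzinternals"
)

func main() {
	var xz bytes.Buffer // stands in for the stream; no stream header/footer here

	filters := []filter.Filter{filter.NewLZMAFilter(8 << 20)}
	bw := &xzinternals.BlockWriter{
		CXZ:       xzinternals.NewCountingWriter(&xz),
		BlockSize: 1 << 20,
		Filters:   filters,
		Hash:      sha256.New(),
	}

	// The filter chain compresses into the counting writer, so that
	// compressedSize and the padding written by Close come out right.
	cfg := filter.WriterConfig{DictCap: 8 << 20}
	fw, err := filter.NewFilterWriteCloser(&cfg, &bw.CXZ, filters)
	if err != nil {
		panic(err)
	}
	bw.W = fw
	bw.MW = io.MultiWriter(fw, bw.Hash)

	// Header first: both sizes are still -1, so only the flags, the
	// filter descriptions, padding and the CRC32 are emitted.
	if err := bw.WriteHeader(&xz); err != nil {
		panic(err)
	}
	if _, err := io.WriteString(bw, "hello, xz block"); err != nil {
		panic(err)
	}
	if err := bw.Close(); err != nil {
		panic(err)
	}
	fmt.Println(bw.Record()) // unpadded and uncompressed size for the index
}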