Nate's Java Refactor
nbauernfeind committed Oct 17, 2024
1 parent 3392d26 commit 3f24477
Showing 2 changed files with 49 additions and 94 deletions.
@@ -24,6 +24,7 @@
import org.apache.arrow.flatbuf.MessageHeader;
import org.apache.arrow.flatbuf.RecordBatch;
import org.apache.arrow.flatbuf.Schema;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;
import java.nio.ByteBuffer;
@@ -86,6 +87,23 @@ public static Schema parseArrowSchema(final BarrageProtoUtil.MessageInfo mi) {
return schema;
}

public static long[] extractBufferInfo(@NotNull final RecordBatch batch) {
final long[] bufferInfo = new long[batch.buffersLength()];
for (int bi = 0; bi < batch.buffersLength(); ++bi) {
int offset = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).offset());
int length = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).length());

if (bi < batch.buffersLength() - 1) {
final int nextOffset =
LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi + 1).offset());
// our parsers handle overhanging buffers
length += Math.max(0, nextOffset - offset - length);
}
bufferInfo[bi] = length;
}
return bufferInfo;
}

@ScriptApi
public synchronized void setSchema(final ByteBuffer ipcMessage) {
// The input ByteBuffer instance (especially originated from Python) can't be assumed to be valid after the
@@ -185,19 +203,7 @@ protected BarrageMessage createBarrageMessage(BarrageProtoUtil.MessageInfo mi, i
new FlatBufferIteratorAdapter<>(batch.nodesLength(),
i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i)));

final long[] bufferInfo = new long[batch.buffersLength()];
for (int bi = 0; bi < batch.buffersLength(); ++bi) {
int offset = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).offset());
int length = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).length());

if (bi < batch.buffersLength() - 1) {
final int nextOffset =
LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi + 1).offset());
// our parsers handle overhanging buffers
length += Math.max(0, nextOffset - offset - length);
}
bufferInfo[bi] = length;
}
final long[] bufferInfo = extractBufferInfo(batch);
final PrimitiveIterator.OfLong bufferInfoIter = Arrays.stream(bufferInfo).iterator();

msg.rowsRemoved = RowSetFactory.empty();
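The new extractBufferInfo helper consolidates a loop that previously appeared three times: once in this file and twice in the second changed file below. The following is a minimal, self-contained sketch of the arithmetic it performs; the plain long[] arrays are hypothetical stand-ins for the flatbuffer-backed batch.buffers(bi).offset()/length() accessors.

import java.util.Arrays;

public class BufferInfoSketch {
    static long[] extractBufferInfo(final long[] offsets, final long[] lengths) {
        final long[] bufferInfo = new long[lengths.length];
        for (int bi = 0; bi < lengths.length; ++bi) {
            long length = lengths[bi];
            if (bi < lengths.length - 1) {
                // Arrow pads buffers, so a buffer may report fewer bytes than the gap
                // to the next buffer's offset; widen the length to cover that gap,
                // since the parsers tolerate overhanging buffers.
                length += Math.max(0, offsets[bi + 1] - offsets[bi] - length);
            }
            bufferInfo[bi] = length;
        }
        return bufferInfo;
    }

    public static void main(final String[] args) {
        // Buffer 0 reports 5 bytes but buffer 1 starts at offset 8, so buffer 0 is
        // widened to 8; the final buffer keeps its reported length.
        System.out.println(Arrays.toString(
                extractBufferInfo(new long[] {0, 8, 24}, new long[] {5, 16, 12})));
        // prints: [8, 16, 12]
    }
}
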
@@ -29,10 +29,7 @@
import io.deephaven.generic.region.*;
import io.deephaven.io.log.impl.LogOutputStringImpl;
import io.deephaven.util.SafeCloseable;
import io.deephaven.util.SafeCloseableList;
import io.deephaven.util.annotations.ScriptApi;
import io.deephaven.util.datastructures.LongSizedDataStructure;
import org.apache.arrow.flatbuf.Message;
import org.apache.arrow.flatbuf.MessageHeader;
import org.apache.arrow.flatbuf.RecordBatch;
import org.apache.arrow.flatbuf.Schema;
@@ -122,26 +119,13 @@ public SchemaPair getTableSchema(
final ByteBuffer[] schemas =
pyTableDataService.call("_table_schema", tableKey.key).getObjectArrayValue(ByteBuffer.class);
final SchemaPair result = new SchemaPair();
result.tableSchema = convertSchema(schemas[0]);
result.partitionSchema = convertSchema(schemas[1]);
result.tableSchema = BarrageUtil.convertArrowSchema(ArrowToTableConverter.parseArrowSchema(
ArrowToTableConverter.parseArrowIpcMessage(schemas[0])));
result.partitionSchema = BarrageUtil.convertArrowSchema(ArrowToTableConverter.parseArrowSchema(
ArrowToTableConverter.parseArrowIpcMessage(schemas[1])));
return result;
}

private BarrageUtil.ConvertedArrowSchema convertSchema(final ByteBuffer original) {
// The Schema instance (especially originated from Python) can't be assumed to be valid after the return
// of this method. Until https://github.com/jpy-consortium/jpy/issues/126 is resolved, we need to make a
// copy of the header to use after the return of this method.

final BarrageProtoUtil.MessageInfo mi = parseArrowIpcMessage(original);
if (mi.header.headerType() != MessageHeader.Schema) {
throw new IllegalArgumentException("The input is not a valid Arrow Schema IPC message");
}
final Schema schema = new Schema();
Message.getRootAsMessage(mi.header.getByteBuffer()).header(schema);

return BarrageUtil.convertArrowSchema(schema);
}

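The removed convertSchema helper above is superseded by parsing utilities that ArrowToTableConverter already exposes, as the getTableSchema hunk shows. Below is a sketch of the equivalent call chain; the package paths and wrapper class are assumptions for illustration, while the method names and types are taken from this diff.

import java.io.IOException;
import java.nio.ByteBuffer;

import io.deephaven.extensions.barrage.util.ArrowToTableConverter;
import io.deephaven.extensions.barrage.util.BarrageUtil;
import org.apache.arrow.flatbuf.Schema;

public class SchemaConversionSketch {
    static BarrageUtil.ConvertedArrowSchema convert(final ByteBuffer ipcMessage) throws IOException {
        // parseArrowIpcMessage copies the flatbuffer header, so the (possibly
        // Python-originated) input buffer need not remain valid afterwards;
        // parseArrowSchema rejects messages whose header is not a Schema.
        final Schema schema = ArrowToTableConverter.parseArrowSchema(
                ArrowToTableConverter.parseArrowIpcMessage(ipcMessage));
        return BarrageUtil.convertArrowSchema(schema);
    }
}
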
/**
* Get the existing partitions for the table.
*
@@ -204,7 +188,7 @@ private void processNewPartition(
final ChunkType[] columnChunkTypes = arrowSchema.computeWireChunkTypes();
final Class<?>[] columnTypes = arrowSchema.computeWireTypes();
final Class<?>[] componentTypes = arrowSchema.computeWireComponentTypes();
for (int i = 0; i < schema.fieldsLength(); i++) {
for (int i = 0; i < columnTypes.length; i++) {
final int factor = (arrowSchema.conversionFactors == null) ? 1 : arrowSchema.conversionFactors[i];
ChunkReader reader = DefaultChunkReadingFactory.INSTANCE.getReader(
BarrageUtil.DEFAULT_SNAPSHOT_DESER_OPTIONS, factor,
@@ -216,43 +200,26 @@
if (recordBatchMessageInfo.header.headerType() != MessageHeader.RecordBatch) {
throw new IllegalArgumentException("byteBuffers[1] is not a valid Arrow RecordBatch IPC message");
}
// The BarrageProtoUtil.MessageInfo instance (especially originated from Python) can't be assumed to be valid
// after the return of this method. Until https://github.com/jpy-consortium/jpy/issues/126 is resolved, we need
// to make a copy of it to use after the return of this method.
final RecordBatch batch = (RecordBatch) recordBatchMessageInfo.header.header(new RecordBatch());

final Iterator<ChunkInputStreamGenerator.FieldNodeInfo> fieldNodeIter =
new FlatBufferIteratorAdapter<>(batch.nodesLength(),
i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i)));

final long[] bufferInfo = new long[batch.buffersLength()];
for (int bi = 0; bi < batch.buffersLength(); ++bi) {
int offset = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).offset());
int length = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).length());

if (bi < batch.buffersLength() - 1) {
final int nextOffset =
LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi + 1).offset());
// our parsers handle overhanging buffers
length += Math.max(0, nextOffset - offset - length);
}
bufferInfo[bi] = length;
}
final long[] bufferInfo = ArrowToTableConverter.extractBufferInfo(batch);
final PrimitiveIterator.OfLong bufferInfoIter = Arrays.stream(bufferInfo).iterator();

// populate the partition values
final int numColumns = schema.fieldsLength();
for (int ci = 0; ci < numColumns; ++ci) {
try (SafeCloseableList toClose = new SafeCloseableList()) {
final WritableChunk<Values> columnValues = readers.get(ci).readChunk(
fieldNodeIter, bufferInfoIter, recordBatchMessageInfo.inputStream, null, 0, 0);
toClose.add(columnValues);
try (final WritableChunk<Values> columnValues = readers.get(ci).readChunk(
fieldNodeIter, bufferInfoIter, recordBatchMessageInfo.inputStream, null, 0, 0)) {

if (columnValues.size() != 1) {
throw new IllegalArgumentException("Expected Single Row: found " + columnValues.size());
}

// partition values are always boxed to make partition value comparisons easier
try (final ChunkBoxer.BoxerKernel boxer =
ChunkBoxer.getBoxer(columnValues.getChunkType(), columnValues.size())) {
// noinspection unchecked
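This hunk also replaces SafeCloseableList with a plain try-with-resources block around the single chunk returned by readChunk, which is why the SafeCloseableList import is dropped above. A minimal sketch of the pattern change, using a hypothetical AutoCloseable stand-in for WritableChunk<Values>:

public class TryWithResourcesSketch {
    // FakeChunk is a hypothetical stand-in for the WritableChunk the real code reads.
    static final class FakeChunk implements AutoCloseable {
        int size() {
            return 1;
        }

        @Override
        public void close() {
            System.out.println("chunk closed");
        }
    }

    public static void main(final String[] args) {
        // The chunk is closed when the block exits, even if the size check throws;
        // no list indirection is needed when each iteration holds exactly one resource.
        try (final FakeChunk columnValues = new FakeChunk()) {
            if (columnValues.size() != 1) {
                throw new IllegalArgumentException("Expected Single Row: found " + columnValues.size());
            }
        }
    }
}
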
@@ -328,64 +295,43 @@ public List<WritableChunk<Values>> getColumnValues(
}

if (messages.length < 2) {
throw new IllegalArgumentException("Expected atleast two Arrow IPC messages: found "
throw new IllegalArgumentException("Expected at least two Arrow IPC messages: found "
+ messages.length);
}

final Schema schema = ArrowToTableConverter.parseArrowSchema(
ArrowToTableConverter.parseArrowIpcMessage(messages[0]));
final BarrageUtil.ConvertedArrowSchema arrowSchema = BarrageUtil.convertArrowSchema(schema);

final ArrayList<ChunkReader> readers = new ArrayList<>();
final ChunkType[] columnChunkTypes = arrowSchema.computeWireChunkTypes();
final Class<?>[] columnTypes = arrowSchema.computeWireTypes();
final Class<?>[] componentTypes = arrowSchema.computeWireComponentTypes();
for (int i = 0; i < schema.fieldsLength(); i++) {
final int factor = (arrowSchema.conversionFactors == null) ? 1 : arrowSchema.conversionFactors[i];
ChunkReader reader = DefaultChunkReadingFactory.INSTANCE.getReader(
BarrageUtil.DEFAULT_SNAPSHOT_DESER_OPTIONS, factor,
typeInfo(columnChunkTypes[i], columnTypes[i], componentTypes[i], schema.fields(i)));
readers.add(reader);
if (schema.fieldsLength() > 1) {
throw new UnsupportedOperationException("More columns returned than requested.");
}

final int factor = (arrowSchema.conversionFactors == null) ? 1 : arrowSchema.conversionFactors[0];
final ChunkReader reader = DefaultChunkReadingFactory.INSTANCE.getReader(
BarrageUtil.DEFAULT_SNAPSHOT_DESER_OPTIONS, factor,
typeInfo(arrowSchema.computeWireChunkTypes()[0],
arrowSchema.computeWireTypes()[0],
arrowSchema.computeWireComponentTypes()[0],
schema.fields(0)));

try {
for (int ii = 1; ii < messages.length; ++ii) {
final BarrageProtoUtil.MessageInfo recordBatchMessageInfo = parseArrowIpcMessage(messages[ii]);
if (recordBatchMessageInfo.header.headerType() != MessageHeader.RecordBatch) {
throw new IllegalArgumentException(
"byteBuffers[1] is not a valid Arrow RecordBatch IPC message");
"byteBuffers[" + ii + "] is not a valid Arrow RecordBatch IPC message");
}
// The BarrageProtoUtil.MessageInfo instance (especially originated from Python) can't be assumed to be valid
// after the return of this method. Until https://github.com/jpy-consortium/jpy/issues/126 is
// resolved, we need to make a copy of it to use after the return of this method.
final RecordBatch batch = (RecordBatch) recordBatchMessageInfo.header.header(new RecordBatch());

final Iterator<ChunkInputStreamGenerator.FieldNodeInfo> fieldNodeIter =
new FlatBufferIteratorAdapter<>(batch.nodesLength(),
i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i)));

final long[] bufferInfo = new long[batch.buffersLength()];
for (int bi = 0; bi < batch.buffersLength(); ++bi) {
int offset = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).offset());
int length = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).length());

if (bi < batch.buffersLength() - 1) {
final int nextOffset =
LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi + 1).offset());
// our parsers handle overhanging buffers
length += Math.max(0, nextOffset - offset - length);
}
bufferInfo[bi] = length;
}
final long[] bufferInfo = ArrowToTableConverter.extractBufferInfo(batch);
final PrimitiveIterator.OfLong bufferInfoIter = Arrays.stream(bufferInfo).iterator();

if (schema.fieldsLength() > 1) {
throw new UnsupportedOperationException("More columns returned than requested.");
}

resultChunks.add(readers.get(0).readChunk(
resultChunks.add(reader.readChunk(
fieldNodeIter, bufferInfoIter, recordBatchMessageInfo.inputStream, null, 0, 0));
}
} catch (final IOException unexpected) {
@@ -402,6 +348,8 @@ public List<WritableChunk<Values>> getColumnValues(
}
}



@Override
protected @NotNull TableLocationProvider makeTableLocationProvider(@NotNull final TableKey tableKey) {
if (!(tableKey instanceof TableKeyImpl)) {
@@ -592,7 +540,8 @@ public int compareTo(@NotNull final TableLocationKey other) {
throw new ClassCastException(String.format("Cannot compare %s to %s", getClass(), other.getClass()));
}
final TableLocationKeyImpl otherTableLocationKey = (TableLocationKeyImpl) other;
// TODO: What exactly is supposed to happen if partition values are equal but these are different locations?
// TODO NOCOMMIT @ryan: What exactly is supposed to happen if partition values are equal but these are
// different locations?
return PartitionsComparator.INSTANCE.compare(partitions, otherTableLocationKey.partitions);
}

@@ -633,7 +582,7 @@ private TableLocationImpl(
}

private synchronized void checkSizeChange(final long newSize) {
// TODO: should we throw if python tells us size decreased? or just ignore smaller sizes?
// TODO NOCOMMIT @ryan: should we throw if python tells us size decreased? or just ignore smaller sizes?
if (size >= newSize) {
return;
}
@@ -656,7 +605,7 @@ public void refresh() {

@Override
public @NotNull List<SortColumn> getSortedColumns() {
// TODO: we may be able to fetch this from the metadata or table definition post conversion
// TODO NOCOMMIT @ryan: we may be able to fetch this from the metadata or table definition post conversion
return List.of();
}

