Skip to content

Commit

Permalink
work
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Jan 4, 2025
1 parent 36623ac commit b493d4d
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 89 deletions.
22 changes: 14 additions & 8 deletions parquet/src/arrow/arrow_reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::arrow::array_reader::{build_array_reader, ArrayReader};
use crate::arrow::schema::{parquet_to_arrow_schema_and_fields, ParquetField};
use crate::arrow::{parquet_to_arrow_field_levels, FieldLevels, ProjectionMask};
use crate::column::page::{PageIterator, PageReader};
use crate::data_type::AsBytes;
use crate::errors::{ParquetError, Result};
use crate::file::metadata::{ParquetMetaData, ParquetMetaDataReader};
use crate::file::reader::{ChunkReader, SerializedPageReader};
Expand All @@ -41,6 +42,7 @@ mod filter;
mod selection;
pub mod statistics;

use crate::encryption::ciphers::FileDecryptor;
#[cfg(feature = "encryption")]
use crate::encryption::ciphers::{CryptoContext, FileDecryptionProperties};

Expand Down Expand Up @@ -709,13 +711,17 @@ impl<T: ChunkReader + 'static> Iterator for ReaderPageIterator<T> {
#[cfg(feature = "encryption")]
let crypto_context = if self.metadata.file_decryptor().is_some() {
let file_decryptor = Arc::new(self.metadata.file_decryptor().clone().unwrap());

let crypto_context = CryptoContext::new(
rg_idx,
self.column_idx,
file_decryptor.clone(),
file_decryptor,
);
let metadata_decryptor = Arc::new(self.metadata.file_decryptor().clone().unwrap());
let column_name = self
.metadata
.file_metadata()
.schema_descr()
.column(self.column_idx);
let data_decryptor =
Arc::new(file_decryptor.get_column_decryptor(column_name.name().as_bytes()));

let crypto_context =
CryptoContext::new(rg_idx, self.column_idx, metadata_decryptor, data_decryptor);
Some(Arc::new(crypto_context))
} else {
None
Expand Down Expand Up @@ -1763,7 +1769,7 @@ mod tests {
assert_eq!(file_metadata.schema_descr().num_columns(), 8);
assert_eq!(
file_metadata.created_by().unwrap(),
"parquet-cpp-arrow version 14.0.0-SNAPSHOT"
"parquet-cpp-arrow version 19.0.0-SNAPSHOT"
);

metadata.metadata.row_groups().iter().for_each(|rg| {
Expand Down
23 changes: 20 additions & 3 deletions parquet/src/encryption/ciphers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,23 @@ impl FileDecryptor {
self.footer_decryptor.unwrap()
}

pub(crate) fn get_column_decryptor(&self, column_name: &[u8]) -> RingGcmBlockDecryptor {
pub(crate) fn get_column_decryptor(&self, column_name: &[u8]) -> FileDecryptor {
if self.decryption_properties.column_keys.is_none() {
return self.clone();
}
let column_keys = &self.decryption_properties.column_keys.clone().unwrap();
let column_key = column_keys[&column_name.to_vec()].clone();
RingGcmBlockDecryptor::new(&column_key)
let decryptor = if let Some(column_key) = column_keys.get(column_name) {
Some(RingGcmBlockDecryptor::new(&column_key))
} else {
None
};

FileDecryptor {
decryption_properties: self.decryption_properties.clone(),
footer_decryptor: decryptor,
aad_file_unique: self.aad_file_unique.clone(),
aad_prefix: self.aad_prefix.clone(),
}
}

pub(crate) fn decryption_properties(&self) -> &FileDecryptionProperties {
Expand All @@ -346,6 +359,10 @@ impl FileDecryptor {
pub(crate) fn aad_prefix(&self) -> &Vec<u8> {
&self.aad_prefix
}

/// Reports whether the underlying decryption properties have a footer key,
/// by delegating to `has_footer_key` on `decryption_properties`.
pub(crate) fn has_footer_key(&self) -> bool {
    let properties = &self.decryption_properties;
    properties.has_footer_key()
}
}

#[derive(Debug, Clone)]
Expand Down
76 changes: 55 additions & 21 deletions parquet/src/file/metadata/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,30 +95,35 @@ mod memory;
pub(crate) mod reader;
mod writer;

use std::ops::Range;
use std::sync::Arc;
use zstd::zstd_safe::WriteBuf;
use crate::format::{BoundaryOrder, ColumnChunk, ColumnCryptoMetaData, ColumnIndex, ColumnMetaData, OffsetIndex, PageLocation, RowGroup, SizeStatistics, SortingColumn};
use crate::encryption::ciphers::{create_footer_aad, create_page_aad, ModuleType};
use crate::basic::{ColumnOrder, Compression, Encoding, Type};
use crate::data_type::AsBytes;
#[cfg(feature = "encryption")]
use crate::encryption::ciphers::FileDecryptor;
use crate::encryption::ciphers::{create_footer_aad, create_page_aad, ModuleType};
use crate::encryption::ciphers::{
BlockDecryptor, DecryptionPropertiesBuilder, FileDecryptionProperties,
};
use crate::errors::{ParquetError, Result};
pub(crate) use crate::file::metadata::memory::HeapSize;
use crate::file::page_encoding_stats::{self, PageEncodingStats};
use crate::file::page_index::index::Index;
use crate::file::page_index::offset_index::OffsetIndexMetaData;
use crate::file::statistics::{self, Statistics};
use crate::format::{
BoundaryOrder, ColumnChunk, ColumnCryptoMetaData, ColumnIndex, ColumnMetaData, OffsetIndex,
PageLocation, RowGroup, SizeStatistics, SortingColumn,
};
use crate::schema::types::{
ColumnDescPtr, ColumnDescriptor, ColumnPath, SchemaDescPtr, SchemaDescriptor,
Type as SchemaType,
};
use crate::thrift::{TCompactSliceInputProtocol, TSerializable};
pub use reader::ParquetMetaDataReader;
use std::ops::Range;
use std::sync::Arc;
pub use writer::ParquetMetaDataWriter;
pub(crate) use writer::ThriftMetadataWriter;
use crate::data_type::AsBytes;
use crate::encryption::ciphers::{BlockDecryptor, DecryptionPropertiesBuilder, FileDecryptionProperties};
use crate::thrift::{TCompactSliceInputProtocol, TSerializable};
use zstd::zstd_safe::WriteBuf;

/// Page level statistics for each column chunk of each row group.
///
Expand Down Expand Up @@ -636,7 +641,11 @@ impl RowGroupMetaData {
}

/// Method to convert from Thrift.
pub fn from_thrift(schema_descr: SchemaDescPtr, mut rg: RowGroup, #[cfg(feature = "encryption")] decryptor: Option<&FileDecryptor>) -> Result<RowGroupMetaData> {
pub fn from_thrift(
schema_descr: SchemaDescPtr,
mut rg: RowGroup,
#[cfg(feature = "encryption")] decryptor: Option<&FileDecryptor>,
) -> Result<RowGroupMetaData> {
if schema_descr.num_columns() != rg.columns.len() {
return Err(general_err!(
"Column count mismatch. Schema has {} columns while Row Group has {}",
Expand All @@ -647,17 +656,30 @@ impl RowGroupMetaData {
let total_byte_size = rg.total_byte_size;
let num_rows = rg.num_rows;
let mut columns = vec![];
for (i, (c, d)) in rg.columns.drain(0..).zip(schema_descr.columns()).enumerate() {
for (i, (c, d)) in rg
.columns
.drain(0..)
.zip(schema_descr.columns())
.enumerate()
{
let cc;
#[cfg(feature = "encryption")]
if let Some(ColumnCryptoMetaData::ENCRYPTIONWITHCOLUMNKEY(crypto_metadata)) = c.crypto_metadata.clone() {
if let Some(ColumnCryptoMetaData::ENCRYPTIONWITHCOLUMNKEY(crypto_metadata)) =
c.crypto_metadata.clone()
{
if decryptor.is_none() {
cc = ColumnChunkMetaData::from_thrift(d.clone(), c)?;
} else {
let column_name = crypto_metadata.path_in_schema.join(".");
let column_decryptor = decryptor.unwrap().get_column_decryptor(column_name.as_bytes());
let column_decryptor = decryptor
.unwrap()
.get_column_decryptor(column_name.as_bytes());
let aad_file_unique = decryptor.unwrap().aad_file_unique();
let aad_prefix = decryptor.unwrap().decryption_properties().aad_prefix().unwrap();
let aad_prefix = decryptor
.unwrap()
.decryption_properties()
.aad_prefix()
.unwrap();
let aad = [aad_prefix.clone(), aad_file_unique.clone()].concat();
// let s = aad.as_slice();
let column_aad = create_page_aad(
Expand All @@ -673,7 +695,10 @@ impl RowGroupMetaData {
// let mut prot = TCompactSliceInputProtocol::new(buf.as_slice());
// decrypted_fmd_buf =
// footer_decryptor.decrypt(prot.as_slice().as_ref(), aad_footer.as_ref())?;
let mut c2 = column_decryptor.decrypt(buf.as_ref(), column_aad.as_ref())?;
let mut c2 = column_decryptor
.footer_decryptor()
.unwrap()
.decrypt(buf.as_ref(), column_aad.as_ref())?;
let mut prot = TCompactSliceInputProtocol::new(c2.as_slice());
let c3 = ColumnChunk::read_from_in_protocol(&mut prot)?;
// let md = ColumnMetaData::from_thrift(c2, d.clone())?;
Expand Down Expand Up @@ -1678,9 +1703,14 @@ mod tests {
.unwrap();

let row_group_exp = row_group_meta.to_thrift();
let row_group_res = RowGroupMetaData::from_thrift(schema_descr, row_group_exp.clone(), #[cfg(feature = "encryption")] None)
.unwrap()
.to_thrift();
let row_group_res = RowGroupMetaData::from_thrift(
schema_descr,
row_group_exp.clone(),
#[cfg(feature = "encryption")]
None,
)
.unwrap()
.to_thrift();

assert_eq!(row_group_res, row_group_exp);
}
Expand Down Expand Up @@ -1759,10 +1789,14 @@ mod tests {
.build()
.unwrap();

let err =
RowGroupMetaData::from_thrift(schema_descr_3cols, row_group_meta_2cols.to_thrift(), #[cfg(feature = "encryption")] None)
.unwrap_err()
.to_string();
let err = RowGroupMetaData::from_thrift(
schema_descr_3cols,
row_group_meta_2cols.to_thrift(),
#[cfg(feature = "encryption")]
None,
)
.unwrap_err()
.to_string();
assert_eq!(
err,
"Parquet error: Column count mismatch. Schema has 3 columns while Row Group has 2"
Expand Down
7 changes: 6 additions & 1 deletion parquet/src/file/metadata/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,12 @@ impl ParquetMetaDataReader {
let mut row_groups = Vec::new();
// TODO: row group filtering
for rg in t_file_metadata.row_groups {
row_groups.push(RowGroupMetaData::from_thrift(schema_descr.clone(), rg, #[cfg(feature = "encryption")] decryptor.as_ref())?);
row_groups.push(RowGroupMetaData::from_thrift(
schema_descr.clone(),
rg,
#[cfg(feature = "encryption")]
decryptor.as_ref(),
)?);
}
let column_orders = Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr);

Expand Down
125 changes: 69 additions & 56 deletions parquet/src/file/serialized_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::basic::{Encoding, Type};
use crate::bloom_filter::Sbbf;
use crate::column::page::{Page, PageMetadata, PageReader};
use crate::compression::{create_codec, Codec};
use crate::encryption::ciphers::RingGcmBlockDecryptor;
#[cfg(feature = "encryption")]
use crate::encryption::ciphers::{create_page_aad, BlockDecryptor, CryptoContext, ModuleType};
use crate::errors::{ParquetError, Result};
Expand All @@ -42,7 +43,6 @@ use std::collections::VecDeque;
use std::iter;
use std::{fs::File, io::Read, path::Path, sync::Arc};
use thrift::protocol::TCompactInputProtocol;
use crate::encryption::ciphers::RingGcmBlockDecryptor;

impl TryFrom<File> for SerializedFileReader<File> {
type Error = ParquetError;
Expand Down Expand Up @@ -347,42 +347,44 @@ pub(crate) fn read_page_header<T: Read>(
#[cfg(feature = "encryption")] crypto_context: Option<Arc<CryptoContext>>,
) -> Result<PageHeader> {
#[cfg(feature = "encryption")]
// if let Some(crypto_context) = crypto_context {
// // crypto_context.data_decryptor().get_column_decryptor()
// let decryptor = &crypto_context.data_decryptor();
// // todo: get column decryptor
// // let file_decryptor = decryptor.ge(crypto_context.column_ordinal);
// // if !decryptor.decryption_properties().has_footer_key() {
// // return Err(general_err!("Missing footer decryptor"));
// // }
// let file_decryptor = decryptor.footer_decryptor();
// let aad_file_unique = decryptor.aad_file_unique();
// let aad_prefix = decryptor.aad_prefix();
//
// let module_type = if crypto_context.dictionary_page {
// ModuleType::DictionaryPageHeader
// } else {
// ModuleType::DataPageHeader
// };
// let aad = create_page_aad(
// [aad_prefix.as_slice(), aad_file_unique.as_slice()].concat().as_slice(),
// module_type,
// crypto_context.row_group_ordinal,
// crypto_context.column_ordinal,
// crypto_context.page_ordinal,
// )?;
//
// let mut len_bytes = [0; 4];
// input.read_exact(&mut len_bytes)?;
// let ciphertext_len = u32::from_le_bytes(len_bytes) as usize;
// let mut ciphertext = vec![0; 4 + ciphertext_len];
// input.read_exact(&mut ciphertext[4..])?;
// let buf = file_decryptor.unwrap().decrypt(&ciphertext, aad.as_ref())?;
//
// let mut prot = TCompactSliceInputProtocol::new(buf.as_slice());
// let page_header = PageHeader::read_from_in_protocol(&mut prot)?;
// return Ok(page_header);
// }
if let Some(crypto_context) = crypto_context {
// crypto_context.data_decryptor().get_column_decryptor()
let decryptor = &crypto_context.data_decryptor();
// todo: get column decryptor
// let file_decryptor = decryptor.ge(crypto_context.column_ordinal);
// if !decryptor.decryption_properties().has_footer_key() {
// return Err(general_err!("Missing footer decryptor"));
// }
let file_decryptor = decryptor.footer_decryptor();
let aad_file_unique = decryptor.aad_file_unique();
let aad_prefix = decryptor.aad_prefix();

let module_type = if crypto_context.dictionary_page {
ModuleType::DictionaryPageHeader
} else {
ModuleType::DataPageHeader
};
let aad = create_page_aad(
[aad_prefix.as_slice(), aad_file_unique.as_slice()]
.concat()
.as_slice(),
module_type,
crypto_context.row_group_ordinal,
crypto_context.column_ordinal,
crypto_context.page_ordinal,
)?;

let mut len_bytes = [0; 4];
input.read_exact(&mut len_bytes)?;
let ciphertext_len = u32::from_le_bytes(len_bytes) as usize;
let mut ciphertext = vec![0; 4 + ciphertext_len];
input.read_exact(&mut ciphertext[4..])?;
let buf = file_decryptor.unwrap().decrypt(&ciphertext, aad.as_ref())?;

let mut prot = TCompactSliceInputProtocol::new(buf.as_slice());
let page_header = PageHeader::read_from_in_protocol(&mut prot)?;
return Ok(page_header);
}

let mut prot = TCompactInputProtocol::new(input);
let page_header = PageHeader::read_from_in_protocol(&mut prot)?;
Expand Down Expand Up @@ -457,27 +459,38 @@ pub(crate) fn decode_page(
let buffer: Bytes = if crypto_context.is_some() {
let crypto_context = crypto_context.as_ref().unwrap();
let decryptor = crypto_context.data_decryptor();
let Some(file_decryptor) = if let Some(f) = decryptor.footer_decryptor().clone() {
// Some(RingGcmBlockDecryptor::new(decryptor..as_ref()))
} else {
decryptor.
};
// let footer_decryptor
// let file_decryptor = if decryptor.has_footer_key() {
// decryptor.footer_decryptor()
// } else {
// // CryptoMetaData::from_thrift(&crypto_context.meta_data)
// // .and_then(|meta| meta.get_page_decryptor(crypto_context.page_ordinal))
// // .ok_or_else(|| general_err!("Missing footer decryptor"))?
// // page_header.data_page_header
// // decryptor.get_column_decryptor(crypto_context.column_ordinal)
// // decryptor.get_column_decryptor(crypto_context.column_ordinal)
// return Err(general_err!("Missing footer decryptor"));
// // TODO: decryptor should have keys for columns
// };
let file_decryptor = decryptor.footer_decryptor();

let module_type = if crypto_context.dictionary_page {
ModuleType::DictionaryPage
if file_decryptor.is_none() {
buffer
} else {
ModuleType::DataPage
};
let aad = create_page_aad(
decryptor.aad_file_unique().as_slice(),
module_type,
crypto_context.row_group_ordinal,
crypto_context.column_ordinal,
crypto_context.page_ordinal,
)?;
let decrypted = file_decryptor.unwrap().decrypt(&buffer.as_ref(), &aad)?;
Bytes::from(decrypted)
let module_type = if crypto_context.dictionary_page {
ModuleType::DictionaryPage
} else {
ModuleType::DataPage
};
let aad = create_page_aad(
decryptor.aad_file_unique().as_slice(),
module_type,
crypto_context.row_group_ordinal,
crypto_context.column_ordinal,
crypto_context.page_ordinal,
)?;
let decrypted = file_decryptor.unwrap().decrypt(&buffer.as_ref(), &aad)?;
Bytes::from(decrypted)
}
} else {
buffer
};
Expand Down

0 comments on commit b493d4d

Please sign in to comment.