|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64 |
|
from datetime import timedelta |
|
import random |
|
import pyarrow.fs as fs |
|
import pyarrow as pa |
|
|
|
import pytest |
|
|
|
encryption_unavailable = False |
|
|
|
try: |
|
import pyarrow.parquet as pq |
|
import pyarrow.dataset as ds |
|
except ImportError: |
|
pq = None |
|
ds = None |
|
|
|
try: |
|
from pyarrow.tests.parquet.encryption import InMemoryKmsClient |
|
import pyarrow.parquet.encryption as pe |
|
except ImportError: |
|
encryption_unavailable = True |
|
|
|
|
|
|
|
# Apply the ``dataset`` marker to every test in this module.
pytestmark = pytest.mark.dataset


# Test key material and the names it is registered under. Each key is 16
# bytes long; create_kms_connection_config() decodes the bytes as UTF-8 and
# stores them in the KMS connection config under the matching name.
FOOTER_KEY = b"0123456789112345"
FOOTER_KEY_NAME = "footer_key"
COL_KEY = b"1234567890123450"
COL_KEY_NAME = "col_key"
|
|
|
|
|
def create_sample_table():
    """Build the six-row sample table written and read back by the tests.

    Returns
    -------
    The result of ``pa.table`` for the columns ``year``, ``n_legs`` and
    ``animal`` (in that order).
    """
    years = [2020, 2022, 2021, 2022, 2019, 2021]
    leg_counts = [2, 2, 4, 4, 5, 100]
    animals = [
        "Flamingo",
        "Parrot",
        "Dog",
        "Horse",
        "Brittle stars",
        "Centipede",
    ]
    columns = {"year": years, "n_legs": leg_counts, "animal": animals}
    return pa.table(columns)
|
|
|
|
|
def create_encryption_config():
    """Return the ``pe.EncryptionConfiguration`` used to write the dataset.

    The footer key is ``FOOTER_KEY_NAME`` and the footer is not left in
    plaintext. The ``n_legs`` and ``animal`` columns are listed under
    ``COL_KEY_NAME``; ``year`` is not listed under any column key.
    """
    encrypted_columns = {COL_KEY_NAME: ["n_legs", "animal"]}
    # NOTE(review): the cache lifetime is given as a timedelta rather than a
    # number of seconds; presumably the API requires that type -- confirm.
    five_minutes = timedelta(minutes=5.0)
    return pe.EncryptionConfiguration(
        footer_key=FOOTER_KEY_NAME,
        column_keys=encrypted_columns,
        encryption_algorithm="AES_GCM_V1",
        plaintext_footer=False,
        cache_lifetime=five_minutes,
        data_key_length_bits=256,
    )
|
|
|
|
|
def create_decryption_config():
    """Return the ``pe.DecryptionConfiguration`` used to read the dataset.

    ``cache_lifetime`` is given as a ``timedelta`` of five minutes (300
    seconds), the same duration and type that create_encryption_config()
    uses, instead of a bare integer whose unit is implicit.
    """
    return pe.DecryptionConfiguration(cache_lifetime=timedelta(minutes=5.0))
|
|
|
|
|
def create_kms_connection_config():
    """Return a ``pe.KmsConnectionConfig`` carrying both test keys.

    Each key is decoded from bytes as UTF-8 and stored in
    ``custom_kms_conf`` under its key name.
    """
    key_material = ((FOOTER_KEY_NAME, FOOTER_KEY), (COL_KEY_NAME, COL_KEY))
    decoded_keys = {name: key.decode("UTF-8") for name, key in key_material}
    return pe.KmsConnectionConfig(custom_kms_conf=decoded_keys)
|
|
|
|
|
def kms_factory(kms_connection_configuration):
    """Create the KMS client handed to ``pe.CryptoFactory``.

    Wraps the given connection configuration in an ``InMemoryKmsClient``.
    """
    client = InMemoryKmsClient(kms_connection_configuration)
    return client
|
|
|
|
|
@pytest.mark.skipif(
    encryption_unavailable, reason="Parquet Encryption is not currently enabled"
)
def test_dataset_encryption_decryption():
    """Round-trip an encrypted Parquet dataset through the dataset API.

    The sample table is written to a mock filesystem with an encryption
    config attached to the write options. The test then checks that:

    * opening the dataset without any decryption settings raises an
      ``IOError`` matching ``"no decryption"``;
    * reading with ``decryption_config`` in the fragment scan options
      returns a table equal to the original;
    * reading with ``decryption_properties`` in the fragment scan options
      returns a table equal to the original.
    """
    table = create_sample_table()

    encryption_config = create_encryption_config()
    decryption_config = create_decryption_config()
    kms_connection_config = create_kms_connection_config()

    crypto_factory = pe.CryptoFactory(kms_factory)
    parquet_encryption_cfg = ds.ParquetEncryptionConfig(
        crypto_factory, kms_connection_config, encryption_config
    )
    parquet_decryption_cfg = ds.ParquetDecryptionConfig(
        crypto_factory, kms_connection_config, decryption_config
    )

    # Write the dataset with the encryption config attached to the Parquet
    # write options.
    write_format = ds.ParquetFileFormat()
    write_options = write_format.make_write_options(
        encryption_config=parquet_encryption_cfg
    )

    mockfs = fs._MockFileSystem()
    mockfs.create_dir("/")

    ds.write_dataset(
        data=table,
        base_dir="sample_dataset",
        format=write_format,
        file_options=write_options,
        filesystem=mockfs,
    )

    # A format without decryption settings must not be able to open the
    # dataset.
    plain_format = ds.ParquetFileFormat()
    with pytest.raises(IOError, match=r"no decryption"):
        ds.dataset("sample_dataset", format=plain_format, filesystem=mockfs)

    # Read back using a decryption config in the fragment scan options.
    config_scan_opts = ds.ParquetFragmentScanOptions(
        decryption_config=parquet_decryption_cfg
    )
    config_format = ds.ParquetFileFormat(
        default_fragment_scan_options=config_scan_opts
    )
    dataset = ds.dataset(
        "sample_dataset", format=config_format, filesystem=mockfs
    )

    assert table.equals(dataset.to_table())

    # Read back using file decryption properties in the fragment scan options.
    decryption_properties = crypto_factory.file_decryption_properties(
        kms_connection_config, decryption_config)
    properties_scan_opts = ds.ParquetFragmentScanOptions(
        decryption_properties=decryption_properties
    )
    properties_format = ds.ParquetFileFormat(
        default_fragment_scan_options=properties_scan_opts
    )
    dataset = ds.dataset(
        "sample_dataset", format=properties_format, filesystem=mockfs
    )

    assert table.equals(dataset.to_table())
|
|
|
|
|
@pytest.mark.skipif(
    not encryption_unavailable, reason="Parquet Encryption is currently enabled"
)
def test_write_dataset_parquet_without_encryption():
    """Check that encryption options are rejected when encryption is missing.

    Runs only when the Parquet encryption imports failed. Passing
    ``encryption_config`` to ``ParquetFileFormat.make_write_options`` must
    then raise ``NotImplementedError``.
    """
    pformat = ds.ParquetFileFormat()

    # The value is irrelevant: the call is expected to fail before using it.
    with pytest.raises(NotImplementedError):
        pformat.make_write_options(encryption_config="some value")
|
|
|
|
|
@pytest.mark.skipif(
    encryption_unavailable, reason="Parquet Encryption is not currently enabled"
)
def test_large_row_encryption_decryption():
    """Encrypt and decrypt a single float32 column of 2**15 + 1 rows.

    The dataset is written encrypted to a mock filesystem and read back
    twice: directly through ``pq.ParquetFile`` with file decryption
    properties, and through ``ds.dataset`` with a decryption config in the
    scan options. Both results must equal the original table.
    """

    class NoOpKmsClient(pe.KmsClient):
        """KMS client that wraps keys by base64-encoding them."""

        def wrap_key(self, key_bytes: bytes, _: str) -> bytes:
            return base64.b64encode(key_bytes)

        def unwrap_key(self, wrapped_key: bytes, _: str) -> bytes:
            return base64.b64decode(wrapped_key)

    def make_kms_client(_connection_config):
        # The connection config is ignored; every call gets a fresh client.
        return NoOpKmsClient()

    row_count = 2**15 + 1
    values = pa.array(
        [random.random() for _ in range(row_count)],
        type=pa.float32(),
    )
    table = pa.Table.from_arrays([values], names=["foo"])

    kms_config = pe.KmsConnectionConfig()
    crypto_factory = pe.CryptoFactory(make_kms_client)
    encryption_config = pe.EncryptionConfiguration(
        footer_key="UNIMPORTANT_KEY",
        column_keys={"UNIMPORTANT_KEY": ["foo"]},
        double_wrapping=True,
        plaintext_footer=False,
        data_key_length_bits=128,
    )
    dataset_encryption = ds.ParquetEncryptionConfig(
        crypto_factory, kms_config, encryption_config
    )
    dataset_decryption = ds.ParquetDecryptionConfig(
        crypto_factory, kms_config, pe.DecryptionConfiguration()
    )
    scan_options = ds.ParquetFragmentScanOptions(
        decryption_config=dataset_decryption
    )
    file_format = ds.ParquetFileFormat(
        default_fragment_scan_options=scan_options
    )
    write_options = file_format.make_write_options(
        encryption_config=dataset_encryption
    )
    file_decryption_properties = crypto_factory.file_decryption_properties(
        kms_config
    )

    mockfs = fs._MockFileSystem()
    mockfs.create_dir("/")

    path = "large-row-test-dataset"
    ds.write_dataset(
        table,
        path,
        format=file_format,
        file_options=write_options,
        filesystem=mockfs,
    )

    # First read: open the written file directly with decryption properties.
    file_path = f"{path}/part-0.parquet"
    parquet_file = pq.ParquetFile(
        file_path,
        decryption_properties=file_decryption_properties,
        filesystem=mockfs,
    )
    table_from_file = parquet_file.read()
    assert table == table_from_file

    # Second read: go through the dataset API with the decrypting format.
    dataset = ds.dataset(path, format=file_format, filesystem=mockfs)
    table_from_dataset = dataset.to_table()
    assert table == table_from_dataset
|
|