jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from datetime import timedelta
import pyarrow as pa
try:
import pyarrow.parquet as pq
import pyarrow.parquet.encryption as pe
except ImportError:
pq = None
pe = None
else:
from pyarrow.tests.parquet.encryption import (
InMemoryKmsClient, verify_file_encrypted)
PARQUET_NAME = 'encrypted_table.in_mem.parquet'
FOOTER_KEY = b"0123456789112345"
FOOTER_KEY_NAME = "footer_key"
COL_KEY = b"1234567890123450"
COL_KEY_NAME = "col_key"
# Marks all of the tests in this module
# Ignore these with pytest ... -m 'not parquet_encryption'
# Ignore these with pytest ... -m 'not parquet'
pytestmark = [
pytest.mark.parquet_encryption,
pytest.mark.parquet
]
@pytest.fixture(scope='module')
def data_table():
data_table = pa.Table.from_pydict({
'a': pa.array([1, 2, 3]),
'b': pa.array(['a', 'b', 'c']),
'c': pa.array(['x', 'y', 'z'])
})
return data_table
@pytest.fixture(scope='module')
def basic_encryption_config():
basic_encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={
COL_KEY_NAME: ["a", "b"],
})
return basic_encryption_config
def setup_encryption_environment(custom_kms_conf):
"""
Sets up and returns the KMS connection configuration and crypto factory
based on provided KMS configuration parameters.
"""
kms_connection_config = pe.KmsConnectionConfig(custom_kms_conf=custom_kms_conf)
def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
# Create our CryptoFactory
crypto_factory = pe.CryptoFactory(kms_factory)
return kms_connection_config, crypto_factory
def write_encrypted_file(path, data_table, footer_key_name, col_key_name,
footer_key, col_key, encryption_config):
"""
Writes an encrypted parquet file based on the provided parameters.
"""
# Setup the custom KMS configuration with provided keys
custom_kms_conf = {
footer_key_name: footer_key.decode("UTF-8"),
col_key_name: col_key.decode("UTF-8"),
}
# Setup encryption environment
kms_connection_config, crypto_factory = setup_encryption_environment(
custom_kms_conf)
# Write the encrypted parquet file
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
return kms_connection_config, crypto_factory
def test_encrypted_parquet_write_read(tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted, and then read it."""
path = tempdir / PARQUET_NAME
# Encrypt the footer with the footer key,
# encrypt column `a` and column `b` with another key,
# keep `c` plaintext
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={
COL_KEY_NAME: ["a", "b"],
},
encryption_algorithm="AES_GCM_V1",
cache_lifetime=timedelta(minutes=5.0),
data_key_length_bits=256)
kms_connection_config, crypto_factory = write_encrypted_file(
path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY,
encryption_config)
verify_file_encrypted(path)
# Read with decryption properties
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
result_table = read_encrypted_parquet(
path, decryption_config, kms_connection_config, crypto_factory)
assert data_table.equals(result_table)
def write_encrypted_parquet(path, table, encryption_config,
kms_connection_config, crypto_factory):
file_encryption_properties = crypto_factory.file_encryption_properties(
kms_connection_config, encryption_config)
assert file_encryption_properties is not None
with pq.ParquetWriter(
path, table.schema,
encryption_properties=file_encryption_properties) as writer:
writer.write_table(table)
def read_encrypted_parquet(path, decryption_config,
kms_connection_config, crypto_factory):
file_decryption_properties = crypto_factory.file_decryption_properties(
kms_connection_config, decryption_config)
assert file_decryption_properties is not None
meta = pq.read_metadata(
path, decryption_properties=file_decryption_properties)
assert meta.num_columns == 3
schema = pq.read_schema(
path, decryption_properties=file_decryption_properties)
assert len(schema.names) == 3
result = pq.ParquetFile(
path, decryption_properties=file_decryption_properties)
return result.read(use_threads=True)
def test_encrypted_parquet_write_read_wrong_key(tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted,
and then read it using wrong keys."""
path = tempdir / PARQUET_NAME
# Encrypt the footer with the footer key,
# encrypt column `a` and column `b` with another key,
# keep `c` plaintext
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={
COL_KEY_NAME: ["a", "b"],
},
encryption_algorithm="AES_GCM_V1",
cache_lifetime=timedelta(minutes=5.0),
data_key_length_bits=256)
write_encrypted_file(path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME,
FOOTER_KEY, COL_KEY, encryption_config)
verify_file_encrypted(path)
wrong_kms_connection_config, wrong_crypto_factory = setup_encryption_environment({
FOOTER_KEY_NAME: COL_KEY.decode("UTF-8"), # Intentionally wrong
COL_KEY_NAME: FOOTER_KEY.decode("UTF-8"), # Intentionally wrong
})
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
with pytest.raises(ValueError, match=r"Incorrect master key used"):
read_encrypted_parquet(
path, decryption_config, wrong_kms_connection_config,
wrong_crypto_factory)
def test_encrypted_parquet_read_no_decryption_config(tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted,
but then try to read it without decryption properties."""
test_encrypted_parquet_write_read(tempdir, data_table)
# Read without decryption properties
with pytest.raises(IOError, match=r"no decryption"):
pq.ParquetFile(tempdir / PARQUET_NAME).read()
def test_encrypted_parquet_read_metadata_no_decryption_config(
tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted,
but then try to read its metadata without decryption properties."""
test_encrypted_parquet_write_read(tempdir, data_table)
# Read metadata without decryption properties
with pytest.raises(IOError, match=r"no decryption"):
pq.read_metadata(tempdir / PARQUET_NAME)
def test_encrypted_parquet_read_schema_no_decryption_config(
tempdir, data_table):
"""Write an encrypted parquet, verify it's encrypted,
but then try to read its schema without decryption properties."""
test_encrypted_parquet_write_read(tempdir, data_table)
with pytest.raises(IOError, match=r"no decryption"):
pq.read_schema(tempdir / PARQUET_NAME)
def test_encrypted_parquet_write_no_col_key(tempdir, data_table):
"""Write an encrypted parquet, but give only footer key,
without column key."""
path = tempdir / 'encrypted_table_no_col_key.in_mem.parquet'
# Encrypt the footer with the footer key
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME)
with pytest.raises(OSError,
match="Either column_keys or uniform_encryption "
"must be set"):
# Write with encryption properties
write_encrypted_file(path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME,
FOOTER_KEY, b"", encryption_config)
def test_encrypted_parquet_write_kms_error(tempdir, data_table,
basic_encryption_config):
"""Write an encrypted parquet, but raise KeyError in KmsClient."""
path = tempdir / 'encrypted_table_kms_error.in_mem.parquet'
encryption_config = basic_encryption_config
# Empty master_keys_map
kms_connection_config = pe.KmsConnectionConfig()
def kms_factory(kms_connection_configuration):
# Empty master keys map will cause KeyError to be raised
# on wrap/unwrap calls
return InMemoryKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
with pytest.raises(KeyError, match="footer_key"):
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
def test_encrypted_parquet_write_kms_specific_error(tempdir, data_table,
basic_encryption_config):
"""Write an encrypted parquet, but raise KeyError in KmsClient."""
path = tempdir / 'encrypted_table_kms_error.in_mem.parquet'
encryption_config = basic_encryption_config
# Empty master_keys_map
kms_connection_config = pe.KmsConnectionConfig()
class ThrowingKmsClient(pe.KmsClient):
"""A KmsClient implementation that throws exception in
wrap/unwrap calls
"""
def __init__(self, config):
"""Create an InMemoryKmsClient instance."""
pe.KmsClient.__init__(self)
self.config = config
def wrap_key(self, key_bytes, master_key_identifier):
raise ValueError("Cannot Wrap Key")
def unwrap_key(self, wrapped_key, master_key_identifier):
raise ValueError("Cannot Unwrap Key")
def kms_factory(kms_connection_configuration):
# Exception thrown in wrap/unwrap calls
return ThrowingKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
with pytest.raises(ValueError, match="Cannot Wrap Key"):
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
def test_encrypted_parquet_write_kms_factory_error(tempdir, data_table,
basic_encryption_config):
"""Write an encrypted parquet, but raise ValueError in kms_factory."""
path = tempdir / 'encrypted_table_kms_factory_error.in_mem.parquet'
encryption_config = basic_encryption_config
# Empty master_keys_map
kms_connection_config = pe.KmsConnectionConfig()
def kms_factory(kms_connection_configuration):
raise ValueError('Cannot create KmsClient')
crypto_factory = pe.CryptoFactory(kms_factory)
with pytest.raises(ValueError,
match="Cannot create KmsClient"):
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
def test_encrypted_parquet_write_kms_factory_type_error(
tempdir, data_table, basic_encryption_config):
"""Write an encrypted parquet, but use wrong KMS client type
that doesn't implement KmsClient."""
path = tempdir / 'encrypted_table_kms_factory_error.in_mem.parquet'
encryption_config = basic_encryption_config
# Empty master_keys_map
kms_connection_config = pe.KmsConnectionConfig()
class WrongTypeKmsClient():
"""This is not an implementation of KmsClient.
"""
def __init__(self, config):
self.master_keys_map = config.custom_kms_conf
def wrap_key(self, key_bytes, master_key_identifier):
return None
def unwrap_key(self, wrapped_key, master_key_identifier):
return None
def kms_factory(kms_connection_configuration):
return WrongTypeKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
with pytest.raises(TypeError):
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
def test_encrypted_parquet_encryption_configuration():
def validate_encryption_configuration(encryption_config):
assert FOOTER_KEY_NAME == encryption_config.footer_key
assert ["a", "b"] == encryption_config.column_keys[COL_KEY_NAME]
assert "AES_GCM_CTR_V1" == encryption_config.encryption_algorithm
assert encryption_config.plaintext_footer
assert not encryption_config.double_wrapping
assert timedelta(minutes=10.0) == encryption_config.cache_lifetime
assert not encryption_config.internal_key_material
assert 192 == encryption_config.data_key_length_bits
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={COL_KEY_NAME: ["a", "b"], },
encryption_algorithm="AES_GCM_CTR_V1",
plaintext_footer=True,
double_wrapping=False,
cache_lifetime=timedelta(minutes=10.0),
internal_key_material=False,
data_key_length_bits=192,
)
validate_encryption_configuration(encryption_config)
encryption_config_1 = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME)
encryption_config_1.column_keys = {COL_KEY_NAME: ["a", "b"], }
encryption_config_1.encryption_algorithm = "AES_GCM_CTR_V1"
encryption_config_1.plaintext_footer = True
encryption_config_1.double_wrapping = False
encryption_config_1.cache_lifetime = timedelta(minutes=10.0)
encryption_config_1.internal_key_material = False
encryption_config_1.data_key_length_bits = 192
validate_encryption_configuration(encryption_config_1)
def test_encrypted_parquet_decryption_configuration():
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=10.0))
assert timedelta(minutes=10.0) == decryption_config.cache_lifetime
decryption_config_1 = pe.DecryptionConfiguration()
decryption_config_1.cache_lifetime = timedelta(minutes=10.0)
assert timedelta(minutes=10.0) == decryption_config_1.cache_lifetime
def test_encrypted_parquet_kms_configuration():
def validate_kms_connection_config(kms_connection_config):
assert "Instance1" == kms_connection_config.kms_instance_id
assert "URL1" == kms_connection_config.kms_instance_url
assert "MyToken" == kms_connection_config.key_access_token
assert ({"key1": "key_material_1", "key2": "key_material_2"} ==
kms_connection_config.custom_kms_conf)
kms_connection_config = pe.KmsConnectionConfig(
kms_instance_id="Instance1",
kms_instance_url="URL1",
key_access_token="MyToken",
custom_kms_conf={
"key1": "key_material_1",
"key2": "key_material_2",
})
validate_kms_connection_config(kms_connection_config)
kms_connection_config_1 = pe.KmsConnectionConfig()
kms_connection_config_1.kms_instance_id = "Instance1"
kms_connection_config_1.kms_instance_url = "URL1"
kms_connection_config_1.key_access_token = "MyToken"
kms_connection_config_1.custom_kms_conf = {
"key1": "key_material_1",
"key2": "key_material_2",
}
validate_kms_connection_config(kms_connection_config_1)
@pytest.mark.xfail(reason="Plaintext footer - reading plaintext column subset"
" reads encrypted columns too")
def test_encrypted_parquet_write_read_plain_footer_single_wrapping(
tempdir, data_table):
"""Write an encrypted parquet, with plaintext footer
and with single wrapping,
verify it's encrypted, and then read plaintext columns."""
path = tempdir / PARQUET_NAME
# Encrypt the footer with the footer key,
# encrypt column `a` and column `b` with another key,
# keep `c` plaintext
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={
COL_KEY_NAME: ["a", "b"],
},
plaintext_footer=True,
double_wrapping=False)
kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={
FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
COL_KEY_NAME: COL_KEY.decode("UTF-8"),
}
)
def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
# # Read without decryption properties only the plaintext column
# result = pq.ParquetFile(path)
# result_table = result.read(columns='c', use_threads=False)
# assert table.num_rows == result_table.num_rows
@pytest.mark.xfail(reason="External key material not supported yet")
def test_encrypted_parquet_write_external(tempdir, data_table):
"""Write an encrypted parquet, with external key
material.
Currently it's not implemented, so should throw
an exception"""
path = tempdir / PARQUET_NAME
# Encrypt the file with the footer key
encryption_config = pe.EncryptionConfiguration(
footer_key=FOOTER_KEY_NAME,
column_keys={},
internal_key_material=False)
kms_connection_config = pe.KmsConnectionConfig(
custom_kms_conf={FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8")}
)
def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)
crypto_factory = pe.CryptoFactory(kms_factory)
# Write with encryption properties
write_encrypted_parquet(path, data_table, encryption_config,
kms_connection_config, crypto_factory)
def test_encrypted_parquet_loop(tempdir, data_table, basic_encryption_config):
"""Write an encrypted parquet, verify it's encrypted,
and then read it multithreaded in a loop."""
path = tempdir / PARQUET_NAME
# Encrypt the footer with the footer key,
# encrypt column `a` and column `b` with another key,
# keep `c` plaintext, defined in basic_encryption_config
kms_connection_config, crypto_factory = write_encrypted_file(
path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY,
basic_encryption_config)
verify_file_encrypted(path)
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
for i in range(50):
# Read with decryption properties
file_decryption_properties = crypto_factory.file_decryption_properties(
kms_connection_config, decryption_config)
assert file_decryption_properties is not None
result = pq.ParquetFile(
path, decryption_properties=file_decryption_properties)
result_table = result.read(use_threads=True)
assert data_table.equals(result_table)
def test_read_with_deleted_crypto_factory(tempdir, data_table, basic_encryption_config):
"""
Test that decryption properties can be used if the crypto factory is no longer alive
"""
path = tempdir / PARQUET_NAME
kms_connection_config, crypto_factory = write_encrypted_file(
path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY,
basic_encryption_config)
verify_file_encrypted(path)
# Create decryption properties and delete the crypto factory that created
# the properties afterwards.
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
file_decryption_properties = crypto_factory.file_decryption_properties(
kms_connection_config, decryption_config)
del crypto_factory
result = pq.ParquetFile(
path, decryption_properties=file_decryption_properties)
result_table = result.read(use_threads=True)
assert data_table.equals(result_table)
def test_encrypted_parquet_read_table(tempdir, data_table, basic_encryption_config):
"""Write an encrypted parquet then read it back using read_table."""
path = tempdir / PARQUET_NAME
# Write the encrypted parquet file using the utility function
kms_connection_config, crypto_factory = write_encrypted_file(
path, data_table, FOOTER_KEY_NAME, COL_KEY_NAME, FOOTER_KEY, COL_KEY,
basic_encryption_config)
decryption_config = pe.DecryptionConfiguration(
cache_lifetime=timedelta(minutes=5.0))
file_decryption_properties = crypto_factory.file_decryption_properties(
kms_connection_config, decryption_config)
# Read the encrypted parquet file using read_table
result_table = pq.read_table(path, decryption_properties=file_decryption_properties)
# Assert that the read table matches the original data
assert data_table.equals(result_table)
# Read the encrypted parquet folder using read_table
result_table = pq.read_table(
tempdir, decryption_properties=file_decryption_properties)
assert data_table.equals(result_table)