|
import sys |
|
from dataclasses import dataclass |
|
from typing import TYPE_CHECKING, Optional, Union |
|
|
|
import pandas as pd |
|
import pyarrow as pa |
|
|
|
import datasets |
|
import datasets.config |
|
from datasets.features.features import require_storage_cast |
|
from datasets.table import table_cast |
|
|
|
|
|
if TYPE_CHECKING: |
|
import sqlite3 |
|
|
|
import sqlalchemy |
|
|
|
|
|
# Module-level logger named after this module, obtained from the datasets logging utilities.
logger = datasets.utils.logging.get_logger(__name__)
|
|
|
|
|
@dataclass
class SqlConfig(datasets.BuilderConfig):
    """BuilderConfig for SQL.

    Stores the query, the connection and the options that the ``Sql`` builder
    passes to ``pandas.read_sql``.

    Raises:
        ValueError: from ``__post_init__`` when ``sql`` or ``con`` is ``None``.
    """

    # Query to execute: a SQL string or a SQLAlchemy selectable. Must not stay None.
    sql: Union[str, "sqlalchemy.sql.Selectable"] = None
    # Connection: a URI string or a live connection/engine object. Must not stay None.
    con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"] = None
    # The five options below are exposed unchanged through ``pd_read_sql_kwargs``.
    index_col: Optional[Union[str, list[str]]] = None
    coerce_float: bool = True
    params: Optional[Union[list, tuple, dict]] = None
    parse_dates: Optional[Union[list, dict]] = None
    columns: Optional[list[str]] = None
    # Passed to ``pandas.read_sql`` as ``chunksize``; ``None`` disables chunking.
    chunksize: Optional[int] = 10_000
    # Optional target features; when set, the builder casts each table to them.
    features: Optional[datasets.Features] = None

    def __post_init__(self):
        """Run the parent validation, then require both ``sql`` and ``con``."""
        super().__post_init__()
        if self.sql is None:
            raise ValueError("sql must be specified")
        if self.con is None:
            raise ValueError("con must be specified")

    def create_config_id(
        self,
        config_kwargs: dict,
        custom_features: Optional[datasets.Features] = None,
    ) -> str:
        """Build the config id from a hash-friendly copy of ``config_kwargs``.

        A non-string ``sql`` is replaced by its compiled SQL text, and a
        non-string ``con`` is replaced by ``id(con)``. The caller's dict is
        not modified.

        Args:
            config_kwargs: Keyword arguments of the config; must contain the
                ``"sql"`` and ``"con"`` keys.
            custom_features: Forwarded unchanged to the parent implementation.

        Returns:
            The config id computed by the parent class from the adjusted kwargs.

        Raises:
            TypeError: if ``sql`` is neither a string nor a SQLAlchemy
                ``Selectable`` (including when sqlalchemy is unavailable or
                has not been imported).
        """
        config_kwargs = config_kwargs.copy()

        sql = config_kwargs["sql"]
        if not isinstance(sql, str):
            sql_text = None
            # sqlalchemy is consulted only when it is installed and already
            # imported, so this check never triggers the import by itself.
            if datasets.config.SQLALCHEMY_AVAILABLE and "sqlalchemy" in sys.modules:
                import sqlalchemy

                if isinstance(sql, sqlalchemy.sql.Selectable):
                    selectable_con = config_kwargs["con"]
                    if isinstance(selectable_con, str):
                        # Only the URI scheme is needed to pick the dialect.
                        scheme = selectable_con.split("://")[0]
                        dialect = sqlalchemy.create_engine(scheme + "://").dialect
                    else:
                        # Connection/engine objects have no ``split``; use their
                        # dialect if they expose one. ``None`` makes ``compile``
                        # fall back to SQLAlchemy's default string compiler.
                        dialect = getattr(selectable_con, "dialect", None)
                    sql_text = str(sql.compile(dialect=dialect))
            if sql_text is None:
                raise TypeError(
                    f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}"
                )
            config_kwargs["sql"] = sql_text

        con = config_kwargs["con"]
        if not isinstance(con, str):
            # A live connection object is represented by its identity, so the
            # resulting id is only stable for the lifetime of that object.
            config_kwargs["con"] = id(con)
            logger.info(
                f"SQL connection 'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as URI string instead."
            )

        return super().create_config_id(config_kwargs, custom_features=custom_features)

    @property
    def pd_read_sql_kwargs(self):
        """Return the subset of options handed to ``pandas.read_sql`` as keyword arguments."""
        return {
            "index_col": self.index_col,
            "columns": self.columns,
            "params": self.params,
            "coerce_float": self.coerce_float,
            "parse_dates": self.parse_dates,
        }
|
|
|
|
|
class Sql(datasets.ArrowBasedBuilder):
    """Arrow-based builder that reads the result of a SQL query via ``pandas.read_sql``."""

    BUILDER_CONFIG_CLASS = SqlConfig

    def _info(self):
        """Return a ``DatasetInfo`` carrying the features set on the config (possibly ``None``)."""
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """Return a single TRAIN split with no generator kwargs; ``dl_manager`` is not used."""
        train_split = datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={})
        return [train_split]

    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
        """Conform ``pa_table`` to the configured features, if any.

        Without configured features the table is returned untouched. If any
        feature requires a storage cast, the table goes through ``table_cast``;
        otherwise it is rebuilt from its columns, selected by schema field
        name, directly under the target schema.
        """
        features = self.config.features
        if features is None:
            return pa_table

        schema = features.arrow_schema
        needs_storage_cast = any(require_storage_cast(feature) for feature in features.values())
        if needs_storage_cast:
            return table_cast(pa_table, schema)

        columns = [pa_table[field.name] for field in schema]
        return pa.Table.from_arrays(columns, schema=schema)

    def _generate_tables(self):
        """Yield ``(chunk_index, table)`` pairs built from the query result.

        Each pandas frame is converted to an Arrow table and passed through
        ``_cast_table`` before being yielded.
        """
        config = self.config
        chunksize = config.chunksize
        result = pd.read_sql(config.sql, config.con, chunksize=chunksize, **config.pd_read_sql_kwargs)
        # With no chunk size the result is wrapped in a one-element list so the
        # same loop handles both the chunked and the unchunked case.
        frames = result if chunksize is not None else [result]
        for chunk_idx, frame in enumerate(frames):
            arrow_table = pa.Table.from_pandas(frame)
            yield chunk_idx, self._cast_table(arrow_table)
|
|