import sys from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union import pandas as pd import pyarrow as pa import datasets import datasets.config from datasets.features.features import require_storage_cast from datasets.table import table_cast if TYPE_CHECKING: import sqlite3 import sqlalchemy logger = datasets.utils.logging.get_logger(__name__) @dataclass class SqlConfig(datasets.BuilderConfig): """BuilderConfig for SQL.""" sql: Union[str, "sqlalchemy.sql.Selectable"] = None con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"] = None index_col: Optional[Union[str, list[str]]] = None coerce_float: bool = True params: Optional[Union[list, tuple, dict]] = None parse_dates: Optional[Union[list, dict]] = None columns: Optional[list[str]] = None chunksize: Optional[int] = 10_000 features: Optional[datasets.Features] = None def __post_init__(self): super().__post_init__() if self.sql is None: raise ValueError("sql must be specified") if self.con is None: raise ValueError("con must be specified") def create_config_id( self, config_kwargs: dict, custom_features: Optional[datasets.Features] = None, ) -> str: config_kwargs = config_kwargs.copy() # We need to stringify the Selectable object to make its hash deterministic # The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html sql = config_kwargs["sql"] if not isinstance(sql, str): if datasets.config.SQLALCHEMY_AVAILABLE and "sqlalchemy" in sys.modules: import sqlalchemy if isinstance(sql, sqlalchemy.sql.Selectable): engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://") sql_str = str(sql.compile(dialect=engine.dialect)) config_kwargs["sql"] = sql_str else: raise TypeError( f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}" ) else: raise TypeError( f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}" ) con = config_kwargs["con"] if not isinstance(con, str): config_kwargs["con"] = id(con) logger.info( f"SQL connection 'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as URI string instead." ) return super().create_config_id(config_kwargs, custom_features=custom_features) @property def pd_read_sql_kwargs(self): pd_read_sql_kwargs = { "index_col": self.index_col, "columns": self.columns, "params": self.params, "coerce_float": self.coerce_float, "parse_dates": self.parse_dates, } return pd_read_sql_kwargs class Sql(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = SqlConfig def _info(self): return datasets.DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={})] def _cast_table(self, pa_table: pa.Table) -> pa.Table: if self.config.features is not None: schema = self.config.features.arrow_schema if all(not require_storage_cast(feature) for feature in self.config.features.values()): # cheaper cast pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema) else: # more expensive cast; allows str <-> int/float or str to Audio for example pa_table = table_cast(pa_table, schema) return pa_table def _generate_tables(self): chunksize = self.config.chunksize sql_reader = pd.read_sql( self.config.sql, self.config.con, chunksize=chunksize, **self.config.pd_read_sql_kwargs ) sql_reader = [sql_reader] if chunksize is None else sql_reader for chunk_idx, df in enumerate(sql_reader): pa_table = pa.Table.from_pandas(df) yield chunk_idx, self._cast_table(pa_table)