File size: 4,495 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import sys
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

import pandas as pd
import pyarrow as pa

import datasets
import datasets.config
from datasets.features.features import require_storage_cast
from datasets.table import table_cast


if TYPE_CHECKING:
    import sqlite3

    import sqlalchemy


# Module-level logger obtained from the `datasets` logging utilities, named after this module.
logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class SqlConfig(datasets.BuilderConfig):
    """BuilderConfig for SQL.

    Both ``sql`` and ``con`` are mandatory: ``__post_init__`` raises
    ``ValueError`` when either of them is left as ``None``.
    """

    # Query to run: a SQL string or a SQLAlchemy selectable. Required.
    sql: Union[str, "sqlalchemy.sql.Selectable"] = None
    # Where to run it: a URI string or an already opened connection/engine. Required.
    con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"] = None
    # The next fields are collected by `pd_read_sql_kwargs` and passed on unchanged.
    index_col: Optional[Union[str, list[str]]] = None
    coerce_float: bool = True
    params: Optional[Union[list, tuple, dict]] = None
    parse_dates: Optional[Union[list, dict]] = None
    columns: Optional[list[str]] = None
    # Not part of `pd_read_sql_kwargs`; consumers read these two from the config directly.
    chunksize: Optional[int] = 10_000
    features: Optional[datasets.Features] = None

    def __post_init__(self):
        """Validate that the two mandatory fields were provided.

        Raises:
            ValueError: if ``sql`` or ``con`` is ``None``.
        """
        super().__post_init__()
        if self.sql is None:
            raise ValueError("sql must be specified")
        if self.con is None:
            raise ValueError("con must be specified")

    def create_config_id(
        self,
        config_kwargs: dict,
        custom_features: Optional[datasets.Features] = None,
    ) -> str:
        """Normalize ``sql`` and ``con`` in a copy of ``config_kwargs``, then defer to the parent.

        Args:
            config_kwargs: Keyword arguments of the config; must contain the
                ``"sql"`` and ``"con"`` keys. The caller's dict is not mutated.
            custom_features: Forwarded unchanged to the parent implementation.

        Returns:
            The value returned by ``super().create_config_id``.

        Raises:
            TypeError: if ``sql`` is neither a string nor a
                ``sqlalchemy.sql.Selectable`` (or is not a string while
                SQLAlchemy is unavailable or has not been imported yet).
        """
        config_kwargs = config_kwargs.copy()
        sql = config_kwargs["sql"]
        con = config_kwargs["con"]

        # We need to stringify the Selectable object to make its hash deterministic
        # The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html
        if not isinstance(sql, str):
            unsupported_sql_msg = (
                f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}"
            )
            # A Selectable can only exist if sqlalchemy is installed and already imported.
            if not (datasets.config.SQLALCHEMY_AVAILABLE and "sqlalchemy" in sys.modules):
                raise TypeError(unsupported_sql_msg)

            import sqlalchemy

            if not isinstance(sql, sqlalchemy.sql.Selectable):
                raise TypeError(unsupported_sql_msg)

            if isinstance(con, str):
                # Only the URI scheme is needed to resolve the dialect.
                dialect = sqlalchemy.create_engine(con.split("://")[0] + "://").dialect
            else:
                # Engine/Connection objects carry their dialect; calling `.split` on them
                # would raise AttributeError. Objects without a `dialect` attribute
                # (e.g. sqlite3.Connection) fall back to SQLAlchemy's default stringification.
                dialect = getattr(con, "dialect", None)
            config_kwargs["sql"] = str(sql.compile(dialect=dialect))

        if not isinstance(con, str):
            # The object identity stands in for the connection, so the resulting id
            # is not stable across processes; this is why the message below is logged.
            config_kwargs["con"] = id(con)
            logger.info(
                f"SQL connection 'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as URI string instead."
            )

        return super().create_config_id(config_kwargs, custom_features=custom_features)

    @property
    def pd_read_sql_kwargs(self) -> dict:
        """Return the config fields that are passed as extra keyword arguments when reading."""
        return {
            "index_col": self.index_col,
            "columns": self.columns,
            "params": self.params,
            "coerce_float": self.coerce_float,
            "parse_dates": self.parse_dates,
        }


class Sql(datasets.ArrowBasedBuilder):
    """Arrow-based builder that reads the result of a SQL query through pandas."""

    BUILDER_CONFIG_CLASS = SqlConfig

    def _info(self):
        """Build the dataset info from the features declared in the config."""
        declared_features = self.config.features
        return datasets.DatasetInfo(features=declared_features)

    def _split_generators(self, dl_manager):
        """Return a single TRAIN split generator with no generation kwargs."""
        train_split = datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={})
        return [train_split]

    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
        """Cast ``pa_table`` to the schema of the configured features, if any.

        The table is returned untouched when the config declares no features.
        """
        features = self.config.features
        if features is None:
            return pa_table

        schema = features.arrow_schema
        needs_storage_cast = any(require_storage_cast(feature) for feature in features.values())
        if needs_storage_cast:
            # more expensive cast; allows str <-> int/float or str to Audio for example
            return table_cast(pa_table, schema)

        # cheaper cast
        ordered_columns = [pa_table[field.name] for field in schema]
        return pa.Table.from_arrays(ordered_columns, schema=schema)

    def _generate_tables(self):
        """Yield ``(chunk index, Arrow table)`` pairs read with ``pd.read_sql``."""
        config = self.config
        chunksize = config.chunksize
        query_result = pd.read_sql(config.sql, config.con, chunksize=chunksize, **config.pd_read_sql_kwargs)
        # Without a chunksize the result is wrapped in a one-element list so the loop below still applies.
        frames = query_result if chunksize is not None else [query_result]
        for chunk_idx, df in enumerate(frames):
            arrow_table = pa.Table.from_pandas(df)
            yield chunk_idx, self._cast_table(arrow_table)