#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import string
from typing import Any, Dict, Optional, Union, List, Sequence, Mapping, Tuple
import uuid
import warnings
import pandas as pd
from pyspark.pandas.internal import InternalFrame
from pyspark.pandas.namespace import _get_index_map
from pyspark import pandas as ps
from pyspark.sql import SparkSession
from pyspark.pandas.utils import default_session
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import Series
# Names exported by ``from pyspark.pandas.sql_formatter import *``.
__all__ = ["sql"]
# This is not used in this file. It's for legacy sql_processor.
# NOTE(review): the meaning of the value is defined by sql_processor, not here.
_CAPTURE_SCOPES = 3
def sql(
    query: str,
    index_col: Optional[Union[str, List[str]]] = None,
    args: Optional[Union[Dict[str, Any], List]] = None,
    **kwargs: Any,
) -> DataFrame:
    """
    Execute a SQL query and return the result as a pandas-on-Spark DataFrame.

    This function acts as a standard Python string formatter with understanding
    the following variable types:

        * pandas-on-Spark DataFrame
        * pandas-on-Spark Series
        * pandas DataFrame
        * pandas Series
        * string

    Also the method can bind named parameters to SQL literals from `args`.

    Parameters
    ----------
    query : str
        the SQL query
    index_col : str or list of str, optional
        Column names to be used in Spark to represent pandas-on-Spark's index. The index name
        in pandas-on-Spark is ignored. By default, the index is always lost.

        .. note:: If you want to preserve the index, explicitly use :func:`DataFrame.reset_index`,
            and pass it to the SQL statement with `index_col` parameter.

            For example,

            >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c'])
            >>> new_psdf = psdf.reset_index()
            >>> ps.sql("SELECT * FROM {new_psdf}", index_col="index", new_psdf=new_psdf)
            ... # doctest: +NORMALIZE_WHITESPACE
                   A  B
            index
            a      1  4
            b      2  5
            c      3  6

            For MultiIndex,

            >>> psdf = ps.DataFrame(
            ...     {"A": [1, 2, 3], "B": [4, 5, 6]},
            ...     index=pd.MultiIndex.from_tuples(
            ...         [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
            ...     ),
            ... )
            >>> new_psdf = psdf.reset_index()
            >>> ps.sql(
            ...     "SELECT * FROM {new_psdf}", index_col=["index1", "index2"], new_psdf=new_psdf)
            ... # doctest: +NORMALIZE_WHITESPACE
                           A  B
            index1 index2
            a      b       1  4
            c      d       2  5
            e      f       3  6

            Also note that the index name(s) should be matched to the existing name.
    args : dict or list
        A dictionary of parameter names to Python objects or a list of Python objects
        that can be converted to SQL literal expressions. See
        <a href="https://spark.apache.org/docs/latest/sql-ref-datatypes.html">
        Supported Data Types</a> for supported value types in Python.
        For example, dictionary keys: "rank", "name", "birthdate";
        dictionary values: 1, "Steven", datetime.date(2023, 4, 2).
        A value can also be a `Column` of a literal expression, in that case it is taken as is.

        .. versionadded:: 3.4.0

        .. versionchanged:: 3.5.0
            Added positional parameters.
    kwargs
        other variables that the user wants to set that can be referenced in the query

    Returns
    -------
    pandas-on-Spark DataFrame

    Notes
    -----
    When the environment variable ``PYSPARK_PANDAS_SQL_LEGACY`` is set to ``"1"``, the query
    is delegated to the legacy ``sql_processor`` and a ``FutureWarning`` is emitted. `args`
    is not passed along in that mode.

    Examples
    --------

    Calling a built-in SQL function.

    >>> ps.sql("SELECT * FROM range(10) where id > 7")
       id
    0   8
    1   9

    >>> ps.sql("SELECT * FROM range(10) WHERE id > {bound1} AND id < {bound2}", bound1=7, bound2=9)
       id
    0   8

    >>> mydf = ps.range(10)
    >>> x = tuple(range(4))
    >>> ps.sql("SELECT {ser} FROM {mydf} WHERE id IN {x}", ser=mydf.id, mydf=mydf, x=x)
       id
    0   0
    1   1
    2   2
    3   3

    Mixing pandas-on-Spark and pandas DataFrames in a join operation. Note that the index is
    dropped.

    >>> ps.sql('''
    ...   SELECT m1.a, m2.b
    ...   FROM {table1} m1 INNER JOIN {table2} m2
    ...   ON m1.key = m2.key
    ...   ORDER BY m1.a, m2.b''',
    ...   table1=ps.DataFrame({"a": [1,2], "key": ["a", "b"]}),
    ...   table2=pd.DataFrame({"b": [3,4,5], "key": ["a", "b", "b"]}))
       a  b
    0  1  3
    1  2  4
    2  2  5

    Also, it is possible to query using Series.

    >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c'])
    >>> ps.sql("SELECT {mydf.A} FROM {mydf}", mydf=psdf)
       A
    0  1
    1  2
    2  3

    And substitute named parameters with the `:` prefix by SQL literals.

    >>> ps.sql("SELECT * FROM range(10) WHERE id > :bound1", args={"bound1":7})
       id
    0   8
    1   9

    Or positional parameters marked by `?` in the SQL query by SQL literals.

    >>> ps.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
       id
    0   8
    1   9
    """
    if os.environ.get("PYSPARK_PANDAS_SQL_LEGACY") == "1":
        # Opt-in legacy implementation; imported lazily so the normal path never loads it.
        from pyspark.pandas import sql_processor

        warnings.warn(
            "Deprecated in 3.3.0, and the legacy behavior "
            "will be removed in the future releases.",
            FutureWarning,
        )
        # NOTE: `args` is intentionally left as in the original call: it is not forwarded
        # to the legacy processor.
        return sql_processor.sql(query, index_col=index_col, **kwargs)

    session = default_session()
    formatter = PandasSQLStringFormatter(session)
    try:
        # Formatting may register temporary views for DataFrames referenced in `query`;
        # `args` is handed to Spark for `:name` / `?` parameter binding.
        sdf = session.sql(formatter.format(query, **kwargs), args)
    finally:
        # Always drop the temporary views the formatter recorded, even when the query fails.
        formatter.clear()

    index_spark_columns, index_names = _get_index_map(sdf, index_col)
    return DataFrame(
        InternalFrame(
            spark_frame=sdf, index_spark_columns=index_spark_columns, index_names=index_names
        )
    )
class PandasSQLStringFormatter(string.Formatter):
    """
    A standard ``string.Formatter`` in Python that can understand pandas-on-Spark instances
    with basic Python objects. This object must be cleared after use for a single SQL
    query; it cannot be reused across multiple SQL queries without calling :meth:`clear`.
    """

    def __init__(self, session: "SparkSession") -> None:
        """Create a formatter that registers and drops temp views through ``session``."""
        self._session: SparkSession = session
        # (pandas-on-Spark DataFrame, temp view name) pairs registered while formatting.
        self._temp_views: List[Tuple[DataFrame, str]] = []
        # (pandas-on-Spark Series, field name) pairs seen while formatting; they are
        # validated against ``_temp_views`` at the end of ``vformat``.
        self._ref_sers: List[Tuple[Series, str]] = []

    def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str:
        """
        Format ``format_string`` and verify every referenced pandas-on-Spark Series.

        Raises
        ------
        ValueError
            If a referenced pandas-on-Spark Series is not one of the Series held by any
            pandas-on-Spark DataFrame referenced in the same query.
        """
        ret = super().vformat(format_string, args, kwargs)

        # The check runs after the whole string is parsed because a Series field may
        # appear before the DataFrame field it belongs to.
        for ref, n in self._ref_sers:
            # A single flat generator is required here: nesting one generator inside
            # another would hand `any` generator objects, which are always truthy.
            held = any(
                ref is psser for df, _ in self._temp_views for psser in df._pssers.values()
            )
            if not held:
                # If referred DataFrame does not hold the given Series, raise an error.
                raise ValueError("The series in {%s} does not refer any dataframe specified." % n)
        return ret

    def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
        """Resolve ``field_name`` as usual, then convert the object into SQL text."""
        obj, first = super().get_field(field_name, args, kwargs)
        return self._convert_value(obj, field_name), first

    def _convert_value(self, val: Any, name: str) -> Optional[str]:
        """
        Converts the given value into a SQL string.

        Series become their column name, DataFrames become the name of a temporary view
        created for them, strings become quoted SQL literals, and any other value is
        returned unchanged.
        """
        if isinstance(val, pd.Series):
            # Return the column name from pandas Series directly.
            return ps.from_pandas(val).to_frame()._to_spark().columns[0]
        elif isinstance(val, Series):
            # Return the column name of pandas-on-Spark Series iff its DataFrame was
            # referred. The check will be done in `vformat` after we parse all.
            self._ref_sers.append((val, name))
            return val.to_frame()._to_spark().columns[0]
        elif isinstance(val, (DataFrame, pd.DataFrame)):
            df_name = "_pandas_api_%s" % str(uuid.uuid4()).replace("-", "")

            if isinstance(val, pd.DataFrame):
                # Don't store temp view for plain pandas instances
                # because it is unable to know which pandas DataFrame
                # holds which Series.
                # NOTE(review): because the view name is not recorded, `clear` does not
                # drop views created for plain pandas DataFrames.
                val = ps.from_pandas(val)
            else:
                # Reuse the view already registered for this exact DataFrame object.
                for df, n in self._temp_views:
                    if df is val:
                        return n
                self._temp_views.append((val, df_name))

            val._to_spark().createOrReplaceTempView(df_name)
            return df_name
        elif isinstance(val, str):
            # This is matched to behavior from JVM implementation.
            # See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/
            # sql/catalyst/expressions/literals.scala`
            return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"
        else:
            return val

    def clear(self) -> None:
        """Drop every recorded temporary view and reset the formatter's state."""
        for _, n in self._temp_views:
            self._session.catalog.dropTempView(n)
        self._temp_views = []
        self._ref_sers = []
def _test() -> None:
    """Run this module's doctests on a local Spark session; exit with -1 on any failure."""
    import doctest
    import os
    import sys

    from pyspark.sql import SparkSession
    import pyspark.pandas.sql_formatter

    # Doctests are executed with SPARK_HOME as the working directory.
    os.chdir(os.environ["SPARK_HOME"])

    module = pyspark.pandas.sql_formatter
    doctest_globals = module.__dict__.copy()
    # The examples refer to the package through the conventional `ps` alias.
    doctest_globals["ps"] = pyspark.pandas

    builder = SparkSession.builder.master("local[4]")
    spark = builder.appName("pyspark.pandas.sql_formatter tests").getOrCreate()

    outcome = doctest.testmod(
        module,
        globs=doctest_globals,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
    )
    spark.stop()

    if outcome.failed:
        sys.exit(-1)


if __name__ == "__main__":
    _test()