
sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
  26logger = logging.getLogger("sqlglot")
  27
  28
  29class Dialects(str, Enum):
  30    """Dialects supported by SQLGLot."""
  31
  32    DIALECT = ""
  33
  34    ATHENA = "athena"
  35    BIGQUERY = "bigquery"
  36    CLICKHOUSE = "clickhouse"
  37    DATABRICKS = "databricks"
  38    DORIS = "doris"
  39    DRILL = "drill"
  40    DUCKDB = "duckdb"
  41    HIVE = "hive"
  42    MYSQL = "mysql"
  43    ORACLE = "oracle"
  44    POSTGRES = "postgres"
  45    PRESTO = "presto"
  46    PRQL = "prql"
  47    REDSHIFT = "redshift"
  48    SNOWFLAKE = "snowflake"
  49    SPARK = "spark"
  50    SPARK2 = "spark2"
  51    SQLITE = "sqlite"
  52    STARROCKS = "starrocks"
  53    TABLEAU = "tableau"
  54    TERADATA = "teradata"
  55    TRINO = "trino"
  56    TSQL = "tsql"
  57
  58
  59class NormalizationStrategy(str, AutoName):
  60    """Specifies the strategy according to which identifiers should be normalized."""
  61
  62    LOWERCASE = auto()
  63    """Unquoted identifiers are lowercased."""
  64
  65    UPPERCASE = auto()
  66    """Unquoted identifiers are uppercased."""
  67
  68    CASE_SENSITIVE = auto()
  69    """Always case-sensitive, regardless of quotes."""
  70
  71    CASE_INSENSITIVE = auto()
  72    """Always case-insensitive, regardless of quotes."""
  73
  74
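# -- Editor's example (a sketch, not part of the module source): the strategy
# can be overridden per Dialect instance, either as a keyword argument or
# inline in the string handled by `Dialect.get_or_raise` (defined below).
#
#     >>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
#     >>> d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
#     True
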
  75class _Dialect(type):
  76    classes: t.Dict[str, t.Type[Dialect]] = {}
  77
  78    def __eq__(cls, other: t.Any) -> bool:
  79        if cls is other:
  80            return True
  81        if isinstance(other, str):
  82            return cls is cls.get(other)
  83        if isinstance(other, Dialect):
  84            return cls is type(other)
  85
  86        return False
  87
  88    def __hash__(cls) -> int:
  89        return hash(cls.__name__.lower())
  90
  91    @classmethod
  92    def __getitem__(cls, key: str) -> t.Type[Dialect]:
  93        return cls.classes[key]
  94
  95    @classmethod
  96    def get(
  97        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
  98    ) -> t.Optional[t.Type[Dialect]]:
  99        return cls.classes.get(key, default)
 100
 101    def __new__(cls, clsname, bases, attrs):
 102        klass = super().__new__(cls, clsname, bases, attrs)
 103        enum = Dialects.__members__.get(clsname.upper())
 104        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 105
 106        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 107        klass.FORMAT_TRIE = (
 108            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 109        )
 110        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 111        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 112
 113        base = seq_get(bases, 0)
 114        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
 115        base_parser = (getattr(base, "parser_class", Parser),)
 116        base_generator = (getattr(base, "generator_class", Generator),)
 117
 118        klass.tokenizer_class = klass.__dict__.get(
 119            "Tokenizer", type("Tokenizer", base_tokenizer, {})
 120        )
 121        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
 122        klass.generator_class = klass.__dict__.get(
 123            "Generator", type("Generator", base_generator, {})
 124        )
 125
 126        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 127        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 128            klass.tokenizer_class._IDENTIFIERS.items()
 129        )[0]
 130
 131        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 132            return next(
 133                (
 134                    (s, e)
 135                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 136                    if t == token_type
 137                ),
 138                (None, None),
 139            )
 140
 141        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 142        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 143        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 144        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 145
 146        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
 147            klass.UNESCAPED_SEQUENCES = {
 148                "\\a": "\a",
 149                "\\b": "\b",
 150                "\\f": "\f",
 151                "\\n": "\n",
 152                "\\r": "\r",
 153                "\\t": "\t",
 154                "\\v": "\v",
 155                "\\\\": "\\",
 156                **klass.UNESCAPED_SEQUENCES,
 157            }
 158
 159        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
 160
 161        if enum not in ("", "bigquery"):
 162            klass.generator_class.SELECT_KINDS = ()
 163
 164        if enum not in ("", "databricks", "hive", "spark", "spark2"):
 165            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
 166            for modifier in ("cluster", "distribute", "sort"):
 167                modifier_transforms.pop(modifier, None)
 168
 169            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
 170
 171        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 172            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 173                TokenType.ANTI,
 174                TokenType.SEMI,
 175            }
 176
 177        return klass
 178
 179
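# -- Editor's example (a sketch, not part of the module source): the metaclass
# registers every Dialect subclass by name and makes dialect classes
# comparable to those names.
#
#     >>> from sqlglot.dialects.duckdb import DuckDB
#     >>> Dialect.get("duckdb") is DuckDB
#     True
#     >>> DuckDB == "duckdb"
#     True
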
 180class Dialect(metaclass=_Dialect):
 181    INDEX_OFFSET = 0
 182    """The base index offset for arrays."""
 183
 184    WEEK_OFFSET = 0
 185    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 186
 187    UNNEST_COLUMN_ONLY = False
 188    """Whether `UNNEST` table aliases are treated as column aliases."""
 189
 190    ALIAS_POST_TABLESAMPLE = False
 191    """Whether the table alias comes after tablesample."""
 192
 193    TABLESAMPLE_SIZE_IS_PERCENT = False
 194    """Whether a size in the table sample clause represents percentage."""
 195
 196    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 197    """Specifies the strategy according to which identifiers should be normalized."""
 198
 199    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 200    """Whether an unquoted identifier can start with a digit."""
 201
 202    DPIPE_IS_STRING_CONCAT = True
 203    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 204
 205    STRICT_STRING_CONCAT = False
 206    """Whether `CONCAT`'s arguments must be strings."""
 207
 208    SUPPORTS_USER_DEFINED_TYPES = True
 209    """Whether user-defined data types are supported."""
 210
 211    SUPPORTS_SEMI_ANTI_JOIN = True
 212    """Whether `SEMI` or `ANTI` joins are supported."""
 213
 214    NORMALIZE_FUNCTIONS: bool | str = "upper"
 215    """
 216    Determines how function names are going to be normalized.
 217    Possible values:
 218        "upper" or True: Convert names to uppercase.
 219        "lower": Convert names to lowercase.
 220        False: Disables function name normalization.
 221    """
 222
 223    LOG_BASE_FIRST: t.Optional[bool] = True
 224    """
 225    Whether the base comes first in the `LOG` function.
 226    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
 227    """
 228
 229    NULL_ORDERING = "nulls_are_small"
 230    """
 231    Default `NULL` ordering method to use if not explicitly set.
 232    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 233    """
 234
 235    TYPED_DIVISION = False
 236    """
 237    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 238    False means `a / b` is always float division.
 239    True means `a / b` is integer division if both `a` and `b` are integers.
 240    """
 241
 242    SAFE_DIVISION = False
 243    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 244
 245    CONCAT_COALESCE = False
 246    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 247
 248    DATE_FORMAT = "'%Y-%m-%d'"
 249    DATEINT_FORMAT = "'%Y%m%d'"
 250    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 251
 252    TIME_MAPPING: t.Dict[str, str] = {}
 253    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 254
 255    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 256    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 257    FORMAT_MAPPING: t.Dict[str, str] = {}
 258    """
 259    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
  260    If empty, the corresponding trie will be constructed from `TIME_MAPPING`.
 261    """
 262
 263    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
 264    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 265
 266    PSEUDOCOLUMNS: t.Set[str] = set()
 267    """
 268    Columns that are auto-generated by the engine corresponding to this dialect.
 269    For example, such columns may be excluded from `SELECT *` queries.
 270    """
 271
 272    PREFER_CTE_ALIAS_COLUMN = False
 273    """
 274    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 275    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 276    any projection aliases in the subquery.
 277
 278    For example,
 279        WITH y(c) AS (
 280            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 281        ) SELECT c FROM y;
 282
 283        will be rewritten as
 284
 285        WITH y(c) AS (
 286            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 287        ) SELECT c FROM y;
 288    """
 289
 290    # --- Autofilled ---
 291
 292    tokenizer_class = Tokenizer
 293    parser_class = Parser
 294    generator_class = Generator
 295
 296    # A trie of the time_mapping keys
 297    TIME_TRIE: t.Dict = {}
 298    FORMAT_TRIE: t.Dict = {}
 299
 300    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 301    INVERSE_TIME_TRIE: t.Dict = {}
 302
 303    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 304
 305    # Delimiters for string literals and identifiers
 306    QUOTE_START = "'"
 307    QUOTE_END = "'"
 308    IDENTIFIER_START = '"'
 309    IDENTIFIER_END = '"'
 310
 311    # Delimiters for bit, hex, byte and unicode literals
 312    BIT_START: t.Optional[str] = None
 313    BIT_END: t.Optional[str] = None
 314    HEX_START: t.Optional[str] = None
 315    HEX_END: t.Optional[str] = None
 316    BYTE_START: t.Optional[str] = None
 317    BYTE_END: t.Optional[str] = None
 318    UNICODE_START: t.Optional[str] = None
 319    UNICODE_END: t.Optional[str] = None
 320
 321    # Separator of COPY statement parameters
 322    COPY_PARAMS_ARE_CSV = True
 323
 324    @classmethod
 325    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 326        """
 327        Look up a dialect in the global dialect registry and return it if it exists.
 328
 329        Args:
 330            dialect: The target dialect. If this is a string, it can be optionally followed by
 331                additional key-value pairs that are separated by commas and are used to specify
 332                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 333
 334        Example:
  335            >>> dialect = get_or_raise("duckdb")
 336            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
 337
 338        Returns:
 339            The corresponding Dialect instance.
 340        """
 341
 342        if not dialect:
 343            return cls()
 344        if isinstance(dialect, _Dialect):
 345            return dialect()
 346        if isinstance(dialect, Dialect):
 347            return dialect
 348        if isinstance(dialect, str):
 349            try:
 350                dialect_name, *kv_pairs = dialect.split(",")
 351                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 352            except ValueError:
 353                raise ValueError(
 354                    f"Invalid dialect format: '{dialect}'. "
 355                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
 356                )
 357
 358            result = cls.get(dialect_name.strip())
 359            if not result:
 360                from difflib import get_close_matches
 361
 362                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 363                if similar:
 364                    similar = f" Did you mean {similar}?"
 365
 366                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 367
 368            return result(**kwargs)
 369
 370        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 371
 372    @classmethod
 373    def format_time(
 374        cls, expression: t.Optional[str | exp.Expression]
 375    ) -> t.Optional[exp.Expression]:
 376        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 377        if isinstance(expression, str):
 378            return exp.Literal.string(
 379                # the time formats are quoted
 380                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 381            )
 382
 383        if expression and expression.is_string:
 384            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 385
 386        return expression
 387
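    # -- Editor's example (a sketch, assuming Hive's TIME_MAPPING covers these
    # tokens): converting a dialect format literal to Python strftime form.
    #
    #     >>> Dialect["hive"].format_time("'yyyy-MM-dd'").sql()
    #     "'%Y-%m-%d'"
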
 388    def __init__(self, **kwargs) -> None:
 389        normalization_strategy = kwargs.get("normalization_strategy")
 390
 391        if normalization_strategy is None:
 392            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 393        else:
 394            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 395
 396    def __eq__(self, other: t.Any) -> bool:
 397        # Does not currently take dialect state into account
 398        return type(self) == other
 399
 400    def __hash__(self) -> int:
 401        # Does not currently take dialect state into account
 402        return hash(type(self))
 403
 404    def normalize_identifier(self, expression: E) -> E:
 405        """
 406        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 407
 408        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 409        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 410        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 411        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 412
 413        There are also dialects like Spark, which are case-insensitive even when quotes are
 414        present, and dialects like MySQL, whose resolution rules match those employed by the
  415        underlying operating system; for example, they may always be case-sensitive on Linux.
 416
 417        Finally, the normalization behavior of some engines can even be controlled through flags,
 418        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 419
 420        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 421        that it can analyze queries in the optimizer and successfully capture their semantics.
 422        """
 423        if (
 424            isinstance(expression, exp.Identifier)
 425            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 426            and (
 427                not expression.quoted
 428                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 429            )
 430        ):
 431            expression.set(
 432                "this",
 433                (
 434                    expression.this.upper()
 435                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 436                    else expression.this.lower()
 437                ),
 438            )
 439
 440        return expression
 441
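    # -- Editor's example (a sketch, not part of the source): Postgres
    # lowercases unquoted identifiers, Snowflake uppercases them, and quoting
    # disables normalization.
    #
    #     >>> from sqlglot import exp
    #     >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
    #     'foo'
    #     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
    #     'FOO'
    #     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO", quoted=True)).name
    #     'FoO'
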
 442    def case_sensitive(self, text: str) -> bool:
 443        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 444        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 445            return False
 446
 447        unsafe = (
 448            str.islower
 449            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 450            else str.isupper
 451        )
 452        return any(unsafe(char) for char in text)
 453
 454    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 455        """Checks if text can be identified given an identify option.
 456
 457        Args:
 458            text: The text to check.
 459            identify:
 460                `"always"` or `True`: Always returns `True`.
 461                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 462
 463        Returns:
 464            Whether the given text can be identified.
 465        """
 466        if identify is True or identify == "always":
 467            return True
 468
 469        if identify == "safe":
 470            return not self.case_sensitive(text)
 471
 472        return False
 473
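    # -- Editor's example (a sketch): under the default LOWERCASE strategy,
    # any uppercase character makes an identifier case-sensitive and hence
    # unsafe to leave unquoted.
    #
    #     >>> d = Dialect.get_or_raise("postgres")
    #     >>> d.case_sensitive("Foo"), d.case_sensitive("foo")
    #     (True, False)
    #     >>> d.can_identify("foo", "safe"), d.can_identify("Foo", "safe")
    #     (True, False)
    #     >>> d.can_identify("Foo", "always")
    #     True
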
 474    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 475        """
 476        Adds quotes to a given identifier.
 477
 478        Args:
 479            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 480            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 481                "unsafe", with respect to its characters and this dialect's normalization strategy.
 482        """
 483        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 484            name = expression.this
 485            expression.set(
 486                "quoted",
 487                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 488            )
 489
 490        return expression
 491
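    # -- Editor's example (a sketch): with identify=False, quotes are only
    # added when the identifier is deemed unsafe for this dialect.
    #
    #     >>> from sqlglot import exp
    #     >>> Dialect.get_or_raise("postgres").quote_identifier(exp.to_identifier("FoO"), identify=False).quoted
    #     True
    #     >>> Dialect.get_or_raise("postgres").quote_identifier(exp.to_identifier("foo"), identify=False).quoted
    #     False
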
 492    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 493        if isinstance(path, exp.Literal):
 494            path_text = path.name
 495            if path.is_number:
 496                path_text = f"[{path_text}]"
 497
 498            try:
 499                return parse_json_path(path_text)
 500            except ParseError as e:
 501                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 502
 503        return path
 504
 505    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 506        return self.parser(**opts).parse(self.tokenize(sql), sql)
 507
 508    def parse_into(
 509        self, expression_type: exp.IntoType, sql: str, **opts
 510    ) -> t.List[t.Optional[exp.Expression]]:
 511        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 512
 513    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 514        return self.generator(**opts).generate(expression, copy=copy)
 515
 516    def transpile(self, sql: str, **opts) -> t.List[str]:
 517        return [
 518            self.generate(expression, copy=False, **opts) if expression else ""
 519            for expression in self.parse(sql)
 520        ]
 521
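    # -- Editor's example (a sketch): a Dialect instance wires the tokenizer,
    # parser and generator together; `transpile` is a same-dialect round trip.
    #
    #     >>> Dialect.get_or_raise("duckdb").transpile("select 1 as x")
    #     ['SELECT 1 AS x']
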
 522    def tokenize(self, sql: str) -> t.List[Token]:
 523        return self.tokenizer.tokenize(sql)
 524
 525    @property
 526    def tokenizer(self) -> Tokenizer:
 527        if not hasattr(self, "_tokenizer"):
 528            self._tokenizer = self.tokenizer_class(dialect=self)
 529        return self._tokenizer
 530
 531    def parser(self, **opts) -> Parser:
 532        return self.parser_class(dialect=self, **opts)
 533
 534    def generator(self, **opts) -> Generator:
 535        return self.generator_class(dialect=self, **opts)
 536
 537
 538DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 539
 540
 541def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 542    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 543
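# -- Editor's example (a sketch): transform helpers like `rename_func` are
# meant to be installed in a dialect Generator's TRANSFORMS mapping, e.g.
# `exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT")`; called directly:
#
#     >>> from sqlglot import exp
#     >>> gen = Dialect.get_or_raise("duckdb").generator()
#     >>> rename_func("APPROX_COUNT_DISTINCT")(gen, exp.ApproxDistinct(this=exp.column("x")))
#     'APPROX_COUNT_DISTINCT(x)'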
 544
 545def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 546    if expression.args.get("accuracy"):
 547        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 548    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 549
 550
 551def if_sql(
 552    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 553) -> t.Callable[[Generator, exp.If], str]:
 554    def _if_sql(self: Generator, expression: exp.If) -> str:
 555        return self.func(
 556            name,
 557            expression.this,
 558            expression.args.get("true"),
 559            expression.args.get("false") or false_value,
 560        )
 561
 562    return _if_sql
 563
 564
 565def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 566    this = expression.this
 567    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 568        this.replace(exp.cast(this, exp.DataType.Type.JSON))
 569
 570    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 571
 572
 573def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 574    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
 575
 576
 577def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
 578    elem = seq_get(expression.expressions, 0)
 579    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
 580        return self.func("ARRAY", elem)
 581    return inline_array_sql(self, expression)
 582
 583
 584def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 585    return self.like_sql(
 586        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 587    )
 588
 589
 590def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 591    zone = self.sql(expression, "this")
 592    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 593
 594
 595def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 596    if expression.args.get("recursive"):
 597        self.unsupported("Recursive CTEs are unsupported")
 598        expression.args["recursive"] = False
 599    return self.with_sql(expression)
 600
 601
 602def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 603    n = self.sql(expression, "this")
 604    d = self.sql(expression, "expression")
 605    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 606
 607
 608def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 609    self.unsupported("TABLESAMPLE unsupported")
 610    return self.sql(expression.this)
 611
 612
 613def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 614    self.unsupported("PIVOT unsupported")
 615    return ""
 616
 617
 618def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 619    return self.cast_sql(expression)
 620
 621
 622def no_comment_column_constraint_sql(
 623    self: Generator, expression: exp.CommentColumnConstraint
 624) -> str:
 625    self.unsupported("CommentColumnConstraint unsupported")
 626    return ""
 627
 628
 629def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 630    self.unsupported("MAP_FROM_ENTRIES unsupported")
 631    return ""
 632
 633
 634def str_position_sql(
 635    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
 636) -> str:
 637    this = self.sql(expression, "this")
 638    substr = self.sql(expression, "substr")
 639    position = self.sql(expression, "position")
 640    instance = expression.args.get("instance") if generate_instance else None
 641    position_offset = ""
 642
 643    if position:
 644        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
 645        this = self.func("SUBSTR", this, position)
 646        position_offset = f" + {position} - 1"
 647
 648    return self.func("STRPOS", this, substr, instance) + position_offset
 649
 650
 651def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 652    return (
 653        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 654    )
 655
 656
 657def var_map_sql(
 658    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 659) -> str:
 660    keys = expression.args["keys"]
 661    values = expression.args["values"]
 662
 663    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 664        self.unsupported("Cannot convert array columns into map.")
 665        return self.func(map_func_name, keys, values)
 666
 667    args = []
 668    for key, value in zip(keys.expressions, values.expressions):
 669        args.append(self.sql(key))
 670        args.append(self.sql(value))
 671
 672    return self.func(map_func_name, *args)
 673
 674
 675def build_formatted_time(
 676    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 677) -> t.Callable[[t.List], E]:
 678    """Helper used for time expressions.
 679
 680    Args:
 681        exp_class: the expression class to instantiate.
 682        dialect: target sql dialect.
 683        default: the default format, True being time.
 684
 685    Returns:
 686        A callable that can be used to return the appropriately formatted time expression.
 687    """
 688
 689    def _builder(args: t.List):
 690        return exp_class(
 691            this=seq_get(args, 0),
 692            format=Dialect[dialect].format_time(
 693                seq_get(args, 1)
 694                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 695            ),
 696        )
 697
 698    return _builder
 699
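# -- Editor's example (a sketch, assuming Hive's TIME_MAPPING): dialects
# register this builder under a function name in their parser's FUNCTIONS
# map, so the format argument is normalized to strftime notation at parse
# time.
#
#     >>> from sqlglot import exp
#     >>> builder = build_formatted_time(exp.StrToTime, "hive")
#     >>> node = builder([exp.column("s"), exp.Literal.string("yyyy-MM-dd")])
#     >>> node.args["format"].sql()
#     "'%Y-%m-%d'"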
 700
 701def time_format(
 702    dialect: DialectType = None,
 703) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 704    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 705        """
 706        Returns the time format for a given expression, unless it's equivalent
 707        to the default time format of the dialect of interest.
 708        """
 709        time_format = self.format_time(expression)
 710        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 711
 712    return _time_format
 713
 714
 715def build_date_delta(
 716    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 717) -> t.Callable[[t.List], E]:
 718    def _builder(args: t.List) -> E:
 719        unit_based = len(args) == 3
 720        this = args[2] if unit_based else seq_get(args, 0)
 721        unit = args[0] if unit_based else exp.Literal.string("DAY")
 722        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 723        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 724
 725    return _builder
 726
 727
 728def build_date_delta_with_interval(
 729    expression_class: t.Type[E],
 730) -> t.Callable[[t.List], t.Optional[E]]:
 731    def _builder(args: t.List) -> t.Optional[E]:
 732        if len(args) < 2:
 733            return None
 734
 735        interval = args[1]
 736
 737        if not isinstance(interval, exp.Interval):
 738            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 739
 740        expression = interval.this
 741        if expression and expression.is_string:
 742            expression = exp.Literal.number(expression.this)
 743
 744        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 745
 746    return _builder
 747
 748
 749def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 750    unit = seq_get(args, 0)
 751    this = seq_get(args, 1)
 752
 753    if isinstance(this, exp.Cast) and this.is_type("date"):
 754        return exp.DateTrunc(unit=unit, this=this)
 755    return exp.TimestampTrunc(this=this, unit=unit)
 756
 757
 758def date_add_interval_sql(
 759    data_type: str, kind: str
 760) -> t.Callable[[Generator, exp.Expression], str]:
 761    def func(self: Generator, expression: exp.Expression) -> str:
 762        this = self.sql(expression, "this")
 763        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
 764        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 765
 766    return func
 767
 768
 769def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 770    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
 771
 772
 773def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 774    if not expression.expression:
 775        from sqlglot.optimizer.annotate_types import annotate_types
 776
 777        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 778        return self.sql(exp.cast(expression.this, target_type))
 779    if expression.text("expression").lower() in TIMEZONES:
 780        return self.sql(
 781            exp.AtTimeZone(
 782                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
 783                zone=expression.expression,
 784            )
 785        )
 786    return self.func("TIMESTAMP", expression.this, expression.expression)
 787
 788
 789def locate_to_strposition(args: t.List) -> exp.Expression:
 790    return exp.StrPosition(
 791        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 792    )
 793
 794
 795def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 796    return self.func(
 797        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 798    )
 799
 800
 801def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 802    return self.sql(
 803        exp.Substring(
 804            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 805        )
 806    )
 807
 808
  809def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
 810    return self.sql(
 811        exp.Substring(
 812            this=expression.this,
 813            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 814        )
 815    )
 816
 817
 818def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 819    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
 820
 821
 822def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 823    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
 824
 825
  826# Used for Presto and DuckDB, whose functions don't support charset and assume utf-8
 827def encode_decode_sql(
 828    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 829) -> str:
 830    charset = expression.args.get("charset")
 831    if charset and charset.name.lower() != "utf-8":
 832        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 833
 834    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 835
 836
 837def min_or_least(self: Generator, expression: exp.Min) -> str:
 838    name = "LEAST" if expression.expressions else "MIN"
 839    return rename_func(name)(self, expression)
 840
 841
 842def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 843    name = "GREATEST" if expression.expressions else "MAX"
 844    return rename_func(name)(self, expression)
 845
 846
 847def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 848    cond = expression.this
 849
 850    if isinstance(expression.this, exp.Distinct):
 851        cond = expression.this.expressions[0]
 852        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 853
 854    return self.func("sum", exp.func("if", cond, 1, 0))
 855
 856
 857def trim_sql(self: Generator, expression: exp.Trim) -> str:
 858    target = self.sql(expression, "this")
 859    trim_type = self.sql(expression, "position")
 860    remove_chars = self.sql(expression, "expression")
 861    collation = self.sql(expression, "collation")
 862
 863    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 864    if not remove_chars and not collation:
 865        return self.trim_sql(expression)
 866
 867    trim_type = f"{trim_type} " if trim_type else ""
 868    remove_chars = f"{remove_chars} " if remove_chars else ""
 869    from_part = "FROM " if trim_type or remove_chars else ""
 870    collation = f" COLLATE {collation}" if collation else ""
 871    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 872
 873
 874def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 875    return self.func("STRPTIME", expression.this, self.format_time(expression))
 876
 877
 878def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 879    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 880
 881
 882def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 883    delim, *rest_args = expression.expressions
 884    return self.sql(
 885        reduce(
 886            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 887            rest_args,
 888        )
 889    )
 890
 891
 892def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 893    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 894    if bad_args:
 895        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 896
 897    return self.func(
 898        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 899    )
 900
 901
 902def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 903    bad_args = list(
 904        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
 905    )
 906    if bad_args:
 907        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 908
 909    return self.func(
 910        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 911    )
 912
 913
 914def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 915    names = []
 916    for agg in aggregations:
 917        if isinstance(agg, exp.Alias):
 918            names.append(agg.alias)
 919        else:
 920            """
 921            This case corresponds to aggregations without aliases being used as suffixes
 922            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 923            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
  924            Otherwise, we'd end up with `col_avg(`foo`)` (notice the quoted identifier).
 925            """
 926            agg_all_unquoted = agg.transform(
 927                lambda node: (
 928                    exp.Identifier(this=node.name, quoted=False)
 929                    if isinstance(node, exp.Identifier)
 930                    else node
 931                )
 932            )
 933            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 934
 935    return names
 936
 937
 938def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 939    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 940
 941
 942# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 943def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 944    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 945
 946
 947def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 948    return self.func("MAX", expression.this)
 949
 950
 951def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 952    a = self.sql(expression.left)
 953    b = self.sql(expression.right)
 954    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 955
 956
 957def is_parse_json(expression: exp.Expression) -> bool:
 958    return isinstance(expression, exp.ParseJSON) or (
 959        isinstance(expression, exp.Cast) and expression.is_type("json")
 960    )
 961
 962
 963def isnull_to_is_null(args: t.List) -> exp.Expression:
 964    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 965
 966
 967def generatedasidentitycolumnconstraint_sql(
 968    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 969) -> str:
 970    start = self.sql(expression, "start") or "1"
 971    increment = self.sql(expression, "increment") or "1"
 972    return f"IDENTITY({start}, {increment})"
 973
 974
 975def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 976    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 977        if expression.args.get("count"):
 978            self.unsupported(f"Only two arguments are supported in function {name}.")
 979
 980        return self.func(name, expression.this, expression.expression)
 981
 982    return _arg_max_or_min_sql
 983
 984
 985def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 986    this = expression.this.copy()
 987
 988    return_type = expression.return_type
 989    if return_type.is_type(exp.DataType.Type.DATE):
 990        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 991        # can truncate timestamp strings, because some dialects can't cast them to DATE
 992        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 993
 994    expression.this.replace(exp.cast(this, return_type))
 995    return expression
 996
 997
 998def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 999    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1000        if cast and isinstance(expression, exp.TsOrDsAdd):
1001            expression = ts_or_ds_add_cast(expression)
1002
1003        return self.func(
1004            name,
1005            unit_to_var(expression),
1006            expression.expression,
1007            expression.this,
1008        )
1009
1010    return _delta_sql
1011
1012
1013def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1014    unit = expression.args.get("unit")
1015
1016    if isinstance(unit, exp.Placeholder):
1017        return unit
1018    if unit:
1019        return exp.Literal.string(unit.name)
1020    return exp.Literal.string(default) if default else None
1021
1022
1023def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1024    unit = expression.args.get("unit")
1025
1026    if isinstance(unit, (exp.Var, exp.Placeholder)):
1027        return unit
1028    return exp.Var(this=default) if default else None
1029
1030
1031def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1032    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1033    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1034    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1035
1036    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1037
1038
1039def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1040    """Remove table refs from columns in when statements."""
1041    alias = expression.this.args.get("alias")
1042
1043    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1044        return self.dialect.normalize_identifier(identifier).name if identifier else None
1045
1046    targets = {normalize(expression.this.this)}
1047
1048    if alias:
1049        targets.add(normalize(alias.this))
1050
1051    for when in expression.expressions:
1052        when.transform(
1053            lambda node: (
1054                exp.column(node.this)
1055                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1056                else node
1057            ),
1058            copy=False,
1059        )
1060
1061    return self.merge_sql(expression)
1062
1063
1064def build_json_extract_path(
1065    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1066) -> t.Callable[[t.List], F]:
1067    def _builder(args: t.List) -> F:
1068        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1069        for arg in args[1:]:
1070            if not isinstance(arg, exp.Literal):
1071                # We use the fallback parser because we can't really transpile non-literals safely
1072                return expr_type.from_arg_list(args)
1073
1074            text = arg.name
1075            if is_int(text):
1076                index = int(text)
1077                segments.append(
1078                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1079                )
1080            else:
1081                segments.append(exp.JSONPathKey(this=text))
1082
1083        # This is done to avoid failing in the expression validator due to the arg count
1084        del args[2:]
1085        return expr_type(
1086            this=seq_get(args, 0),
1087            expression=exp.JSONPath(expressions=segments),
1088            only_json_types=arrow_req_json_type,
1089        )
1090
1091    return _builder
1092
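# -- Editor's example (a sketch): literal arguments are folded into a single
# JSONPath expression, while non-literal paths fall back to the plain
# argument list.
#
#     >>> from sqlglot import exp
#     >>> builder = build_json_extract_path(exp.JSONExtract)
#     >>> node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
#     >>> [type(p).__name__ for p in node.expression.expressions]
#     ['JSONPathRoot', 'JSONPathKey', 'JSONPathSubscript']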
1093
1094def json_extract_segments(
1095    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1096) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1097    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1098        path = expression.expression
1099        if not isinstance(path, exp.JSONPath):
1100            return rename_func(name)(self, expression)
1101
1102        segments = []
1103        for segment in path.expressions:
1104            path = self.sql(segment)
1105            if path:
1106                if isinstance(segment, exp.JSONPathPart) and (
1107                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1108                ):
1109                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1110
1111                segments.append(path)
1112
1113        if op:
1114            return f" {op} ".join([self.sql(expression.this), *segments])
1115        return self.func(name, expression.this, *segments)
1116
1117    return _json_extract_segments
1118
1119
1120def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1121    if isinstance(expression.this, exp.JSONPathWildcard):
1122        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1123
1124    return expression.name
1125
1126
1127def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1128    cond = expression.expression
1129    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1130        alias = cond.expressions[0]
1131        cond = cond.this
1132    elif isinstance(cond, exp.Predicate):
1133        alias = "_u"
1134    else:
1135        self.unsupported("Unsupported filter condition")
1136        return ""
1137
1138    unnest = exp.Unnest(expressions=[expression.this])
1139    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1140    return self.sql(exp.Array(expressions=[filtered]))
1141
1142
1143def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
1144    return self.func(
1145        "TO_NUMBER",
1146        expression.this,
1147        expression.args.get("format"),
1148        expression.args.get("nlsparam"),
1149    )
447
448        unsafe = (
449            str.islower
450            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
451            else str.isupper
452        )
453        return any(unsafe(char) for char in text)
454
455    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
456        """Checks if text can be identified given an identify option.
457
458        Args:
459            text: The text to check.
460            identify:
461                `"always"` or `True`: Always returns `True`.
462                `"safe"`: Only returns `True` if the identifier is case-insensitive.
463
464        Returns:
465            Whether the given text can be identified.
466        """
467        if identify is True or identify == "always":
468            return True
469
470        if identify == "safe":
471            return not self.case_sensitive(text)
472
473        return False
474
475    def quote_identifier(self, expression: E, identify: bool = True) -> E:
476        """
477        Adds quotes to a given identifier.
478
479        Args:
480            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
481            identify: If set to `False`, the quotes will only be added if the identifier is deemed
482                "unsafe", with respect to its characters and this dialect's normalization strategy.
483        """
484        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
485            name = expression.this
486            expression.set(
487                "quoted",
488                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
489            )
490
491        return expression
492
493    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
494        if isinstance(path, exp.Literal):
495            path_text = path.name
496            if path.is_number:
497                path_text = f"[{path_text}]"
498
499            try:
500                return parse_json_path(path_text)
501            except ParseError as e:
502                logger.warning(f"Invalid JSON path syntax. {str(e)}")
503
504        return path
505
506    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
507        return self.parser(**opts).parse(self.tokenize(sql), sql)
508
509    def parse_into(
510        self, expression_type: exp.IntoType, sql: str, **opts
511    ) -> t.List[t.Optional[exp.Expression]]:
512        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
513
514    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
515        return self.generator(**opts).generate(expression, copy=copy)
516
517    def transpile(self, sql: str, **opts) -> t.List[str]:
518        return [
519            self.generate(expression, copy=False, **opts) if expression else ""
520            for expression in self.parse(sql)
521        ]
522
523    def tokenize(self, sql: str) -> t.List[Token]:
524        return self.tokenizer.tokenize(sql)
525
526    @property
527    def tokenizer(self) -> Tokenizer:
528        if not hasattr(self, "_tokenizer"):
529            self._tokenizer = self.tokenizer_class(dialect=self)
530        return self._tokenizer
531
532    def parser(self, **opts) -> Parser:
533        return self.parser_class(dialect=self, **opts)
534
535    def generator(self, **opts) -> Generator:
536        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG)

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (e.g. \\n) to its unescaped version (e.g. a literal newline).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
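A hedged sketch of this rewrite in action, assuming Snowflake is one of the dialects that sets this flag:

import sqlglot

sql = "WITH y(c) AS (SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y"
print(sqlglot.transpile(sql, read="snowflake")[0])
# Expected, per the example above:
# WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y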
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
COPY_PARAMS_ARE_CSV = True
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
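A minimal sketch of the lookup in practice, using dialect names as registered in the dialect registry:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

try:
    Dialect.get_or_raise("duckbd")  # misspelled on purpose
except ValueError as error:
    print(error)  # Unknown dialect 'duckbd'. Did you mean duckdb?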

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:

Converts a time format in this dialect to its equivalent Python strftime format.
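For instance, under the Hive dialect (whose TIME_MAPPING is assumed to include entries such as 'yyyy' -> '%Y'), a quoted format string translates as follows; a minimal sketch:

from sqlglot.dialects import Hive

literal = Hive.format_time("'yyyy-MM-dd'")  # note: the input format is quoted
print(literal.sql())  # '%Y-%m-%d'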

def normalize_identifier(self, expression: ~E) -> ~E:

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
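A minimal sketch of these behaviors, assuming the default strategies of the postgres and snowflake dialects:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

unquoted = exp.to_identifier("FoO")
print(Dialect.get_or_raise("postgres").normalize_identifier(unquoted.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(unquoted.copy()).name)  # FOO

quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)  # FoO (left intact)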

def case_sensitive(self, text: str) -> bool:

Checks if text contains any case sensitive characters, based on the dialect's rules.

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify:
    "always" or True: Always returns True.
    "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
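Together with case_sensitive above, this drives identifier quoting; a small sketch, assuming postgres uses the default LOWERCASE strategy:

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")
print(d.case_sensitive("foo"))          # False: lowercasing would not change it
print(d.case_sensitive("Foo"))          # True: the uppercase F is significant
print(d.can_identify("Foo", "safe"))    # False
print(d.can_identify("Foo", "always"))  # True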

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
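A minimal sketch, again assuming postgres defaults:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")
print(d.quote_identifier(exp.to_identifier("foo")).sql())                  # "foo"
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo
print(d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql())  # "Foo" (case-sensitive)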
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
def transpile(self, sql: str, **opts) -> List[str]:
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
tokenizer: sqlglot.tokens.Tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
def generator(self, **opts) -> sqlglot.generator.Generator:
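These methods compose the dialect's tokenizer, parser and generator into one pipeline; a minimal sketch:

from sqlglot.dialects.dialect import Dialect

duckdb = Dialect.get_or_raise("duckdb")
tokens = duckdb.tokenize("SELECT 1 AS x")  # List[Token]
tree = duckdb.parse("SELECT 1 AS x")[0]    # parsed expression tree
print(duckdb.generate(tree))               # SELECT 1 AS x
print(duckdb.transpile("SELECT 1 x"))      # ['SELECT 1 AS x']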
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
542def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
543    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
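A small sketch of the returned callable, handed a generator and an expression (the node construction here is purely illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, rename_func

gen = Dialect.get_or_raise("duckdb").generator()
node = exp.ApproxDistinct(this=exp.column("a"))
print(rename_func("APPROX_COUNT_DISTINCT")(gen, node))  # APPROX_COUNT_DISTINCT(a)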
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
546def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
547    if expression.args.get("accuracy"):
548        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
549    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
552def if_sql(
553    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
554) -> t.Callable[[Generator, exp.If], str]:
555    def _if_sql(self: Generator, expression: exp.If) -> str:
556        return self.func(
557            name,
558            expression.this,
559            expression.args.get("true"),
560            expression.args.get("false") or false_value,
561        )
562
563    return _if_sql
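For example, a dialect could register if_sql("IFF", exp.null()) to emit an IFF call with an explicit NULL fallback; a hedged sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, if_sql

gen = Dialect.get_or_raise("duckdb").generator()
node = exp.If(this=exp.condition("x > 0"), true=exp.Literal.number(1))
print(if_sql("IFF", exp.null())(gen, node))  # IFF(x > 0, 1, NULL)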
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
566def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
567    this = expression.this
568    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
569        this.replace(exp.cast(this, exp.DataType.Type.JSON))
570
571    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
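A quick sketch of the arrow operators it emits (node construction is illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, arrow_json_extract_sql

gen = Dialect.get_or_raise("postgres").generator()
node = exp.JSONExtract(this=exp.column("j"), expression=exp.Literal.string("$.a"))
print(arrow_json_extract_sql(gen, node))  # j -> '$.a'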
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
574def inline_array_sql(self: Generator, expression: exp.Array) -> str:
575    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
def inline_array_unless_query( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
578def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
579    elem = seq_get(expression.expressions, 0)
580    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
581        return self.func("ARRAY", elem)
582    return inline_array_sql(self, expression)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
585def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
586    return self.like_sql(
587        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
588    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
591def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
592    zone = self.sql(expression, "this")
593    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
596def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
597    if expression.args.get("recursive"):
598        self.unsupported("Recursive CTEs are unsupported")
599        expression.args["recursive"] = False
600    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
603def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
604    n = self.sql(expression, "this")
605    d = self.sql(expression, "expression")
606    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
609def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
610    self.unsupported("TABLESAMPLE unsupported")
611    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
614def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
615    self.unsupported("PIVOT unsupported")
616    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
619def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
620    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
623def no_comment_column_constraint_sql(
624    self: Generator, expression: exp.CommentColumnConstraint
625) -> str:
626    self.unsupported("CommentColumnConstraint unsupported")
627    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
630def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
631    self.unsupported("MAP_FROM_ENTRIES unsupported")
632    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
635def str_position_sql(
636    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
637) -> str:
638    this = self.sql(expression, "this")
639    substr = self.sql(expression, "substr")
640    position = self.sql(expression, "position")
641    instance = expression.args.get("instance") if generate_instance else None
642    position_offset = ""
643
644    if position:
645        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
646        this = self.func("SUBSTR", this, position)
647        position_offset = f" + {position} - 1"
648
649    return self.func("STRPOS", this, substr, instance) + position_offset
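A sketch of the emitted SQL, including the offset normalization applied when a position argument is present (node construction is illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, str_position_sql

gen = Dialect.get_or_raise("duckdb").generator()
node = exp.StrPosition(this=exp.column("s"), substr=exp.Literal.string("x"))
print(str_position_sql(gen, node))  # STRPOS(s, 'x')

with_pos = exp.StrPosition(
    this=exp.column("s"), substr=exp.Literal.string("x"), position=exp.Literal.number(3)
)
print(str_position_sql(gen, with_pos))  # STRPOS(SUBSTR(s, 3), 'x') + 3 - 1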
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
652def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
653    return (
654        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
655    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
658def var_map_sql(
659    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
660) -> str:
661    keys = expression.args["keys"]
662    values = expression.args["values"]
663
664    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
665        self.unsupported("Cannot convert array columns into map.")
666        return self.func(map_func_name, keys, values)
667
668    args = []
669    for key, value in zip(keys.expressions, values.expressions):
670        args.append(self.sql(key))
671        args.append(self.sql(value))
672
673    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
676def build_formatted_time(
677    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
678) -> t.Callable[[t.List], E]:
679    """Helper used for time expressions.
680
681    Args:
682        exp_class: the expression class to instantiate.
683        dialect: target sql dialect.
684        default: the default format; if True, the dialect's TIME_FORMAT is used.
685
686    Returns:
687        A callable that can be used to return the appropriately formatted time expression.
688    """
689
690    def _builder(args: t.List):
691        return exp_class(
692            this=seq_get(args, 0),
693            format=Dialect[dialect].format_time(
694                seq_get(args, 1)
695                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
696            ),
697        )
698
699    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format; if True, the dialect's TIME_FORMAT is used.
Returns:

A callable that can be used to return the appropriately formatted time expression.
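A hedged sketch, assuming the hive dialect registers TIME_FORMAT as 'yyyy-MM-dd HH:mm:ss':

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "hive", default=True)
node = builder([exp.column("dt")])  # no explicit format, so the dialect's TIME_FORMAT is used
print(node.args["format"].sql())    # '%Y-%m-%d %H:%M:%S'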

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
702def time_format(
703    dialect: DialectType = None,
704) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
705    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
706        """
707        Returns the time format for a given expression, unless it's equivalent
708        to the default time format of the dialect of interest.
709        """
710        time_format = self.format_time(expression)
711        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
712
713    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
716def build_date_delta(
717    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
718) -> t.Callable[[t.List], E]:
719    def _builder(args: t.List) -> E:
720        unit_based = len(args) == 3
721        this = args[2] if unit_based else seq_get(args, 0)
722        unit = args[0] if unit_based else exp.Literal.string("DAY")
723        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
724        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
725
726    return _builder
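A minimal sketch of the two calling conventions it accepts (argument lists are illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import build_date_delta

builder = build_date_delta(exp.DateAdd)

two_args = builder([exp.column("d"), exp.Literal.number(1)])
print(type(two_args).__name__, two_args.args["unit"].name)  # DateAdd DAY

unit_first = builder([exp.var("MONTH"), exp.Literal.number(1), exp.column("d")])
print(unit_first.args["unit"].name)  # MONTH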
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
729def build_date_delta_with_interval(
730    expression_class: t.Type[E],
731) -> t.Callable[[t.List], t.Optional[E]]:
732    def _builder(args: t.List) -> t.Optional[E]:
733        if len(args) < 2:
734            return None
735
736        interval = args[1]
737
738        if not isinstance(interval, exp.Interval):
739            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
740
741        expression = interval.this
742        if expression and expression.is_string:
743            expression = exp.Literal.number(expression.this)
744
745        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
746
747    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
750def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
751    unit = seq_get(args, 0)
752    this = seq_get(args, 1)
753
754    if isinstance(this, exp.Cast) and this.is_type("date"):
755        return exp.DateTrunc(unit=unit, this=this)
756    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
759def date_add_interval_sql(
760    data_type: str, kind: str
761) -> t.Callable[[Generator, exp.Expression], str]:
762    def func(self: Generator, expression: exp.Expression) -> str:
763        this = self.sql(expression, "this")
764        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
765        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
766
767    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
770def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
771    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
774def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
775    if not expression.expression:
776        from sqlglot.optimizer.annotate_types import annotate_types
777
778        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
779        return self.sql(exp.cast(expression.this, target_type))
780    if expression.text("expression").lower() in TIMEZONES:
781        return self.sql(
782            exp.AtTimeZone(
783                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
784                zone=expression.expression,
785            )
786        )
787    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
790def locate_to_strposition(args: t.List) -> exp.Expression:
791    return exp.StrPosition(
792        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
793    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
796def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
797    return self.func(
798        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
799    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
802def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
803    return self.sql(
804        exp.Substring(
805            this=expression.this, start=exp.Literal.number(1), length=expression.expression
806        )
807    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
810def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
811    return self.sql(
812        exp.Substring(
813            this=expression.this,
814            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
815        )
816    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
819def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
820    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
823def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
824    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
828def encode_decode_sql(
829    self: Generator, expression: exp.Expression, name: str, replace: bool = True
830) -> str:
831    charset = expression.args.get("charset")
832    if charset and charset.name.lower() != "utf-8":
833        self.unsupported(f"Expected utf-8 character set, got {charset}.")
834
835    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
838def min_or_least(self: Generator, expression: exp.Min) -> str:
839    name = "LEAST" if expression.expressions else "MIN"
840    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
843def max_or_greatest(self: Generator, expression: exp.Max) -> str:
844    name = "GREATEST" if expression.expressions else "MAX"
845    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
848def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
849    cond = expression.this
850
851    if isinstance(expression.this, exp.Distinct):
852        cond = expression.this.expressions[0]
853        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
854
855    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
858def trim_sql(self: Generator, expression: exp.Trim) -> str:
859    target = self.sql(expression, "this")
860    trim_type = self.sql(expression, "position")
861    remove_chars = self.sql(expression, "expression")
862    collation = self.sql(expression, "collation")
863
864    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
865    if not remove_chars and not collation:
866        return self.trim_sql(expression)
867
868    trim_type = f"{trim_type} " if trim_type else ""
869    remove_chars = f"{remove_chars} " if remove_chars else ""
870    from_part = "FROM " if trim_type or remove_chars else ""
871    collation = f" COLLATE {collation}" if collation else ""
872    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
875def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
876    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
879def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
880    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
883def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
884    delim, *rest_args = expression.expressions
885    return self.sql(
886        reduce(
887            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
888            rest_args,
889        )
890    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
893def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
894    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
895    if bad_args:
896        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
897
898    return self.func(
899        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
900    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
903def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
904    bad_args = list(
905        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
906    )
907    if bad_args:
908        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
909
910    return self.func(
911        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
912    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
915def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
916    names = []
917    for agg in aggregations:
918        if isinstance(agg, exp.Alias):
919            names.append(agg.alias)
920        else:
921            """
922            This case corresponds to aggregations without aliases being used as suffixes
923            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
924            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
925            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
926            """
927            agg_all_unquoted = agg.transform(
928                lambda node: (
929                    exp.Identifier(this=node.name, quoted=False)
930                    if isinstance(node, exp.Identifier)
931                    else node
932                )
933            )
934            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
935
936    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
939def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
940    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
944def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
945    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
948def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
949    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
952def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
953    a = self.sql(expression.left)
954    b = self.sql(expression.right)
955    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
958def is_parse_json(expression: exp.Expression) -> bool:
959    return isinstance(expression, exp.ParseJSON) or (
960        isinstance(expression, exp.Cast) and expression.is_type("json")
961    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
964def isnull_to_is_null(args: t.List) -> exp.Expression:
965    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
968def generatedasidentitycolumnconstraint_sql(
969    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
970) -> str:
971    start = self.sql(expression, "start") or "1"
972    increment = self.sql(expression, "increment") or "1"
973    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
976def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
977    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
978        if expression.args.get("count"):
979            self.unsupported(f"Only two arguments are supported in function {name}.")
980
981        return self.func(name, expression.this, expression.expression)
982
983    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
986def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
987    this = expression.this.copy()
988
989    return_type = expression.return_type
990    if return_type.is_type(exp.DataType.Type.DATE):
991        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
992        # can truncate timestamp strings, because some dialects can't cast them to DATE
993        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
994
995    expression.this.replace(exp.cast(this, return_type))
996    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
 999def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1000    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1001        if cast and isinstance(expression, exp.TsOrDsAdd):
1002            expression = ts_or_ds_add_cast(expression)
1003
1004        return self.func(
1005            name,
1006            unit_to_var(expression),
1007            expression.expression,
1008            expression.this,
1009        )
1010
1011    return _delta_sql
def unit_to_str( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1014def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1015    unit = expression.args.get("unit")
1016
1017    if isinstance(unit, exp.Placeholder):
1018        return unit
1019    if unit:
1020        return exp.Literal.string(unit.name)
1021    return exp.Literal.string(default) if default else None
def unit_to_var( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1024def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1025    unit = expression.args.get("unit")
1026
1027    if isinstance(unit, (exp.Var, exp.Placeholder)):
1028        return unit
1029    return exp.Var(this=default) if default else None
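The difference between unit_to_str and unit_to_var, sketched (node construction is illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import unit_to_str, unit_to_var

node = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1), unit=exp.var("MONTH"))
print(unit_to_str(node).sql())  # 'MONTH' (a quoted string literal)
print(unit_to_var(node).sql())  # MONTH (a bare variable)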
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
1032def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1033    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1034    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1035    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1036
1037    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
1040def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1041    """Remove table refs from columns in when statements."""
1042    alias = expression.this.args.get("alias")
1043
1044    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1045        return self.dialect.normalize_identifier(identifier).name if identifier else None
1046
1047    targets = {normalize(expression.this.this)}
1048
1049    if alias:
1050        targets.add(normalize(alias.this))
1051
1052    for when in expression.expressions:
1053        when.transform(
1054            lambda node: (
1055                exp.column(node.this)
1056                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1057                else node
1058            ),
1059            copy=False,
1060        )
1061
1062    return self.merge_sql(expression)

Remove table refs from columns in when statements.

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False) -> Callable[[List], ~F]:
1065def build_json_extract_path(
1066    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1067) -> t.Callable[[t.List], F]:
1068    def _builder(args: t.List) -> F:
1069        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1070        for arg in args[1:]:
1071            if not isinstance(arg, exp.Literal):
1072                # We use the fallback parser because we can't really transpile non-literals safely
1073                return expr_type.from_arg_list(args)
1074
1075            text = arg.name
1076            if is_int(text):
1077                index = int(text)
1078                segments.append(
1079                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1080                )
1081            else:
1082                segments.append(exp.JSONPathKey(this=text))
1083
1084        # This is done to avoid failing in the expression validator due to the arg count
1085        del args[2:]
1086        return expr_type(
1087            this=seq_get(args, 0),
1088            expression=exp.JSONPath(expressions=segments),
1089            only_json_types=arrow_req_json_type,
1090        )
1091
1092    return _builder
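A sketch of the path assembly for a JSON_EXTRACT(j, 'a', 0)-style argument list (illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("j"), exp.Literal.string("a"), exp.Literal.number(0)])
print([type(part).__name__ for part in node.expression.expressions])
# ['JSONPathRoot', 'JSONPathKey', 'JSONPathSubscript']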
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1095def json_extract_segments(
1096    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1097) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1098    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1099        path = expression.expression
1100        if not isinstance(path, exp.JSONPath):
1101            return rename_func(name)(self, expression)
1102
1103        segments = []
1104        for segment in path.expressions:
1105            path = self.sql(segment)
1106            if path:
1107                if isinstance(segment, exp.JSONPathPart) and (
1108                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1109                ):
1110                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1111
1112                segments.append(path)
1113
1114        if op:
1115            return f" {op} ".join([self.sql(expression.this), *segments])
1116        return self.func(name, expression.this, *segments)
1117
1118    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1121def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1122    if isinstance(expression.this, exp.JSONPathWildcard):
1123        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1124
1125    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1128def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1129    cond = expression.expression
1130    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1131        alias = cond.expressions[0]
1132        cond = cond.this
1133    elif isinstance(cond, exp.Predicate):
1134        alias = "_u"
1135    else:
1136        self.unsupported("Unsupported filter condition")
1137        return ""
1138
1139    unnest = exp.Unnest(expressions=[expression.this])
1140    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1141    return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param(self, expression: sqlglot.expressions.ToNumber) -> str:
1144def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
1145    return self.func(
1146        "TO_NUMBER",
1147        expression.this,
1148        expression.args.get("format"),
1149        expression.args.get("nlsparam"),
1150    )