sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                "\\a": "\a",
                "\\b": "\b",
                "\\f": "\f",
                "\\n": "\n",
                "\\r": "\r",
                "\\t": "\t",
                "\\v": "\v",
                "\\\\": "\\",
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents a percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Separator of COPY statement parameters
    COPY_PARAMS_ARE_CSV = True

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql

def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression

def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )
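The rest of this page collects a few usage sketches for the machinery above. They are illustrative rather than normative, and they assume sqlglot is installed and that importing the `sqlglot` package registers the bundled dialects. First, the registry lookup performed by `Dialect.get_or_raise`, including the comma-separated settings form its docstring describes:

import sqlglot  # importing the package registers the bundled dialects
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Plain lookup by registry key; note that an *instance* is returned.
duckdb = Dialect.get_or_raise("duckdb")

# The string form can carry comma-separated settings, parsed into keyword arguments.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

An unknown name raises ValueError, with a "Did you mean ...?" suggestion when a close match exists.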
30class Dialects(str, Enum): 31 """Dialects supported by SQLGLot.""" 32 33 DIALECT = "" 34 35 ATHENA = "athena" 36 BIGQUERY = "bigquery" 37 CLICKHOUSE = "clickhouse" 38 DATABRICKS = "databricks" 39 DORIS = "doris" 40 DRILL = "drill" 41 DUCKDB = "duckdb" 42 HIVE = "hive" 43 MYSQL = "mysql" 44 ORACLE = "oracle" 45 POSTGRES = "postgres" 46 PRESTO = "presto" 47 PRQL = "prql" 48 REDSHIFT = "redshift" 49 SNOWFLAKE = "snowflake" 50 SPARK = "spark" 51 SPARK2 = "spark2" 52 SQLITE = "sqlite" 53 STARROCKS = "starrocks" 54 TABLEAU = "tableau" 55 TERADATA = "teradata" 56 TRINO = "trino" 57 TSQL = "tsql"
Dialects supported by SQLGLot.
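The `normalize_identifier` docstring describes per-dialect case resolution; a small sketch of what that means in practice, assuming the stock Postgres (LOWERCASE) and Snowflake (UPPERCASE) strategies:

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")

# Postgres lowercases unquoted identifiers, Snowflake uppercases them.
assert Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name == "foo"
assert Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name == "FOO"

# Quoting makes the identifier case-sensitive, so it is left untouched.
quoted = exp.to_identifier("FoO", quoted=True)
assert Dialect.get_or_raise("postgres").normalize_identifier(quoted).name == "FoO"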
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
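`case_sensitive` and `can_identify` flip their notion of "unsafe" characters with the strategy. A sketch under Snowflake's UPPERCASE default:

import sqlglot
from sqlglot.dialects.dialect import Dialect

snowflake = Dialect.get_or_raise("snowflake")

# Under an UPPERCASE strategy, lowercase characters are the case-sensitive ones.
assert snowflake.case_sensitive("foo")
assert not snowflake.case_sensitive("FOO")

# "safe" identification therefore succeeds only for the all-uppercase spelling.
assert snowflake.can_identify("FOO", "safe")
assert not snowflake.can_identify("foo", "safe")
assert snowflake.can_identify("foo", "always")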
60class NormalizationStrategy(str, AutoName): 61 """Specifies the strategy according to which identifiers should be normalized.""" 62 63 LOWERCASE = auto() 64 """Unquoted identifiers are lowercased.""" 65 66 UPPERCASE = auto() 67 """Unquoted identifiers are uppercased.""" 68 69 CASE_SENSITIVE = auto() 70 """Always case-sensitive, regardless of quotes.""" 71 72 CASE_INSENSITIVE = auto() 73 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
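`parse`, `generate`, and `transpile` are thin wrappers over the tokenizer, parser, and generator classes that `_Dialect.__new__` wires onto each dialect. A round-trip sketch:

import sqlglot
from sqlglot.dialects.dialect import Dialect

dialect = Dialect.get_or_raise("duckdb")

# parse() returns one syntax tree per statement in the input string.
tree = dialect.parse("SELECT 1 AS x")[0]
assert dialect.generate(tree) == "SELECT 1 AS x"

# transpile() is parse + generate over every statement.
assert dialect.transpile("SELECT 1; SELECT 2") == ["SELECT 1", "SELECT 2"]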
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
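`format_time` translates a dialect-native format literal into Python `strftime` tokens via `TIME_MAPPING` and its trie. For instance, MySQL spells minutes as `%i`, which its mapping is expected to rewrite to Python's `%M`:

import sqlglot
from sqlglot.dialects.dialect import Dialect

mysql = Dialect.get_or_raise("mysql")

# When given a raw string, the classmethod expects the surrounding quotes.
assert mysql.format_time("'%i'").name == "%M"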
181class Dialect(metaclass=_Dialect): 182 INDEX_OFFSET = 0 183 """The base index offset for arrays.""" 184 185 WEEK_OFFSET = 0 186 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 187 188 UNNEST_COLUMN_ONLY = False 189 """Whether `UNNEST` table aliases are treated as column aliases.""" 190 191 ALIAS_POST_TABLESAMPLE = False 192 """Whether the table alias comes after tablesample.""" 193 194 TABLESAMPLE_SIZE_IS_PERCENT = False 195 """Whether a size in the table sample clause represents percentage.""" 196 197 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 198 """Specifies the strategy according to which identifiers should be normalized.""" 199 200 IDENTIFIERS_CAN_START_WITH_DIGIT = False 201 """Whether an unquoted identifier can start with a digit.""" 202 203 DPIPE_IS_STRING_CONCAT = True 204 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 205 206 STRICT_STRING_CONCAT = False 207 """Whether `CONCAT`'s arguments must be strings.""" 208 209 SUPPORTS_USER_DEFINED_TYPES = True 210 """Whether user-defined data types are supported.""" 211 212 SUPPORTS_SEMI_ANTI_JOIN = True 213 """Whether `SEMI` or `ANTI` joins are supported.""" 214 215 NORMALIZE_FUNCTIONS: bool | str = "upper" 216 """ 217 Determines how function names are going to be normalized. 218 Possible values: 219 "upper" or True: Convert names to uppercase. 220 "lower": Convert names to lowercase. 221 False: Disables function name normalization. 222 """ 223 224 LOG_BASE_FIRST: t.Optional[bool] = True 225 """ 226 Whether the base comes first in the `LOG` function. 227 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 228 """ 229 230 NULL_ORDERING = "nulls_are_small" 231 """ 232 Default `NULL` ordering method to use if not explicitly set. 233 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 234 """ 235 236 TYPED_DIVISION = False 237 """ 238 Whether the behavior of `a / b` depends on the types of `a` and `b`. 239 False means `a / b` is always float division. 240 True means `a / b` is integer division if both `a` and `b` are integers. 241 """ 242 243 SAFE_DIVISION = False 244 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 245 246 CONCAT_COALESCE = False 247 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 248 249 DATE_FORMAT = "'%Y-%m-%d'" 250 DATEINT_FORMAT = "'%Y%m%d'" 251 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 252 253 TIME_MAPPING: t.Dict[str, str] = {} 254 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 255 256 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 257 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 258 FORMAT_MAPPING: t.Dict[str, str] = {} 259 """ 260 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 261 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 262 """ 263 264 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 265 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 266 267 PSEUDOCOLUMNS: t.Set[str] = set() 268 """ 269 Columns that are auto-generated by the engine corresponding to this dialect. 
270 For example, such columns may be excluded from `SELECT *` queries. 271 """ 272 273 PREFER_CTE_ALIAS_COLUMN = False 274 """ 275 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 276 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 277 any projection aliases in the subquery. 278 279 For example, 280 WITH y(c) AS ( 281 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 282 ) SELECT c FROM y; 283 284 will be rewritten as 285 286 WITH y(c) AS ( 287 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 288 ) SELECT c FROM y; 289 """ 290 291 # --- Autofilled --- 292 293 tokenizer_class = Tokenizer 294 parser_class = Parser 295 generator_class = Generator 296 297 # A trie of the time_mapping keys 298 TIME_TRIE: t.Dict = {} 299 FORMAT_TRIE: t.Dict = {} 300 301 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 302 INVERSE_TIME_TRIE: t.Dict = {} 303 304 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 305 306 # Delimiters for string literals and identifiers 307 QUOTE_START = "'" 308 QUOTE_END = "'" 309 IDENTIFIER_START = '"' 310 IDENTIFIER_END = '"' 311 312 # Delimiters for bit, hex, byte and unicode literals 313 BIT_START: t.Optional[str] = None 314 BIT_END: t.Optional[str] = None 315 HEX_START: t.Optional[str] = None 316 HEX_END: t.Optional[str] = None 317 BYTE_START: t.Optional[str] = None 318 BYTE_END: t.Optional[str] = None 319 UNICODE_START: t.Optional[str] = None 320 UNICODE_END: t.Optional[str] = None 321 322 # Separator of COPY statement parameters 323 COPY_PARAMS_ARE_CSV = True 324 325 @classmethod 326 def get_or_raise(cls, dialect: DialectType) -> Dialect: 327 """ 328 Look up a dialect in the global dialect registry and return it if it exists. 329 330 Args: 331 dialect: The target dialect. If this is a string, it can be optionally followed by 332 additional key-value pairs that are separated by commas and are used to specify 333 dialect settings, such as whether the dialect's identifiers are case-sensitive. 334 335 Example: 336 >>> dialect = dialect_class = get_or_raise("duckdb") 337 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 338 339 Returns: 340 The corresponding Dialect instance. 341 """ 342 343 if not dialect: 344 return cls() 345 if isinstance(dialect, _Dialect): 346 return dialect() 347 if isinstance(dialect, Dialect): 348 return dialect 349 if isinstance(dialect, str): 350 try: 351 dialect_name, *kv_pairs = dialect.split(",") 352 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 353 except ValueError: 354 raise ValueError( 355 f"Invalid dialect format: '{dialect}'. " 356 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 357 ) 358 359 result = cls.get(dialect_name.strip()) 360 if not result: 361 from difflib import get_close_matches 362 363 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 364 if similar: 365 similar = f" Did you mean {similar}?" 
366 367 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 368 369 return result(**kwargs) 370 371 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 372 373 @classmethod 374 def format_time( 375 cls, expression: t.Optional[str | exp.Expression] 376 ) -> t.Optional[exp.Expression]: 377 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 378 if isinstance(expression, str): 379 return exp.Literal.string( 380 # the time formats are quoted 381 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 382 ) 383 384 if expression and expression.is_string: 385 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 386 387 return expression 388 389 def __init__(self, **kwargs) -> None: 390 normalization_strategy = kwargs.get("normalization_strategy") 391 392 if normalization_strategy is None: 393 self.normalization_strategy = self.NORMALIZATION_STRATEGY 394 else: 395 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 396 397 def __eq__(self, other: t.Any) -> bool: 398 # Does not currently take dialect state into account 399 return type(self) == other 400 401 def __hash__(self) -> int: 402 # Does not currently take dialect state into account 403 return hash(type(self)) 404 405 def normalize_identifier(self, expression: E) -> E: 406 """ 407 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 408 409 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 410 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 411 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 412 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 413 414 There are also dialects like Spark, which are case-insensitive even when quotes are 415 present, and dialects like MySQL, whose resolution rules match those employed by the 416 underlying operating system, for example they may always be case-sensitive in Linux. 417 418 Finally, the normalization behavior of some engines can even be controlled through flags, 419 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 420 421 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 422 that it can analyze queries in the optimizer and successfully capture their semantics. 423 """ 424 if ( 425 isinstance(expression, exp.Identifier) 426 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 427 and ( 428 not expression.quoted 429 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 430 ) 431 ): 432 expression.set( 433 "this", 434 ( 435 expression.this.upper() 436 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 437 else expression.this.lower() 438 ), 439 ) 440 441 return expression 442 443 def case_sensitive(self, text: str) -> bool: 444 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 445 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 446 return False 447 448 unsafe = ( 449 str.islower 450 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 451 else str.isupper 452 ) 453 return any(unsafe(char) for char in text) 454 455 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 456 """Checks if text can be identified given an identify option. 
457 458 Args: 459 text: The text to check. 460 identify: 461 `"always"` or `True`: Always returns `True`. 462 `"safe"`: Only returns `True` if the identifier is case-insensitive. 463 464 Returns: 465 Whether the given text can be identified. 466 """ 467 if identify is True or identify == "always": 468 return True 469 470 if identify == "safe": 471 return not self.case_sensitive(text) 472 473 return False 474 475 def quote_identifier(self, expression: E, identify: bool = True) -> E: 476 """ 477 Adds quotes to a given identifier. 478 479 Args: 480 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 481 identify: If set to `False`, the quotes will only be added if the identifier is deemed 482 "unsafe", with respect to its characters and this dialect's normalization strategy. 483 """ 484 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 485 name = expression.this 486 expression.set( 487 "quoted", 488 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 489 ) 490 491 return expression 492 493 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 494 if isinstance(path, exp.Literal): 495 path_text = path.name 496 if path.is_number: 497 path_text = f"[{path_text}]" 498 499 try: 500 return parse_json_path(path_text) 501 except ParseError as e: 502 logger.warning(f"Invalid JSON path syntax. {str(e)}") 503 504 return path 505 506 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 507 return self.parser(**opts).parse(self.tokenize(sql), sql) 508 509 def parse_into( 510 self, expression_type: exp.IntoType, sql: str, **opts 511 ) -> t.List[t.Optional[exp.Expression]]: 512 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 513 514 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 515 return self.generator(**opts).generate(expression, copy=copy) 516 517 def transpile(self, sql: str, **opts) -> t.List[str]: 518 return [ 519 self.generate(expression, copy=False, **opts) if expression else "" 520 for expression in self.parse(sql) 521 ] 522 523 def tokenize(self, sql: str) -> t.List[Token]: 524 return self.tokenizer.tokenize(sql) 525 526 @property 527 def tokenizer(self) -> Tokenizer: 528 if not hasattr(self, "_tokenizer"): 529 self._tokenizer = self.tokenizer_class(dialect=self) 530 return self._tokenizer 531 532 def parser(self, **opts) -> Parser: 533 return self.parser_class(dialect=self, **opts) 534 535 def generator(self, **opts) -> Generator: 536 return self.generator_class(dialect=self, **opts)
389 def __init__(self, **kwargs) -> None: 390 normalization_strategy = kwargs.get("normalization_strategy") 391 392 if normalization_strategy is None: 393 self.normalization_strategy = self.NORMALIZATION_STRATEGY 394 else: 395 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
Whether the behavior of a / b
depends on the types of a
and b
.
False means a / b
is always float division.
True means a / b
is integer division if both a
and b
are integers.
A NULL
arg in CONCAT
yields NULL
by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (\n
) to its unescaped version (
).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
PREFER_CTE_ALIAS_COLUMN: Some dialects, such as Snowflake, allow you to reference a CTE
column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns
to override any projection aliases in the subquery. For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
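Taken together, these class-level settings are what a concrete dialect overrides. A minimal sketch of a hypothetical custom dialect (MyDialect and all attribute values are illustrative, not taken from any real dialect):

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    class MyDialect(Dialect):
        # Unquoted identifiers resolve case-insensitively in this hypothetical engine.
        NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

        # NULLs sort as if they were the smallest possible value.
        NULL_ORDERING = "nulls_are_small"

        # This engine's time format tokens, mapped to Python strftime tokens.
        TIME_MAPPING = {"yyyy": "%Y", "MM": "%m", "dd": "%d"}

    # The _Dialect metaclass registers the subclass under its lowercased name.
    d = Dialect.get_or_raise("mydialect")
    assert d.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE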
    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
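A sketch of the accepted input shapes and the failure mode:

    from sqlglot.dialects.dialect import Dialect

    # A name, a name plus settings, a class, or an instance are all accepted.
    d1 = Dialect.get_or_raise("duckdb")
    d2 = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    d3 = Dialect.get_or_raise(None)  # falls back to the base Dialect

    # Unknown names raise a ValueError, with a close-match hint when available.
    try:
        Dialect.get_or_raise("duckbd")
    except ValueError as e:
        print(e)  # Unknown dialect 'duckbd'. Did you mean duckdb?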
    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression
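A sketch using the Hive dialect, whose TIME_MAPPING translates Java-style tokens (e.g. yyyy, MM, dd):

    from sqlglot.dialects.dialect import Dialect

    hive = Dialect["hive"]

    # Plain strings are assumed to be quoted time formats and are unwrapped first.
    print(hive.format_time("'yyyy-MM-dd'").name)  # %Y-%m-%d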
    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression
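A sketch of the behaviors described above:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")

    # Postgres lowercases unquoted identifiers; Snowflake uppercases them.
    print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
    print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

    # Quoted identifiers are treated as case-sensitive and left untouched.
    quoted = exp.to_identifier("FoO", quoted=True)
    print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)  # FoO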
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)
    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False
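A sketch under the base dialect's default LOWERCASE strategy:

    from sqlglot.dialects.dialect import Dialect

    d = Dialect()  # the base dialect defaults to LOWERCASE normalization

    print(d.case_sensitive("foo"))          # False: lowercasing would change nothing
    print(d.case_sensitive("Foo"))          # True: the uppercase F is significant
    print(d.can_identify("foo", "safe"))    # True
    print(d.can_identify("Foo", "safe"))    # False
    print(d.can_identify("Foo", "always"))  # True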
    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression
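A sketch using the base dialect:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    d = Dialect()

    # identify=True (the default) forces quoting.
    print(d.quote_identifier(exp.to_identifier("foo")).sql())  # "foo"

    # With identify=False, only "unsafe" identifiers are quoted.
    print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())   # foo
    print(d.quote_identifier(exp.to_identifier("fo o"), identify=False).sql())  # "fo o"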
    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path
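A sketch of the coercion (the output comments assume the default generator's JSON path rendering):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    d = Dialect()

    # String literals are parsed into structured JSONPath expressions.
    print(d.to_json_path(exp.Literal.string("$.a.b")).sql())  # $.a.b

    # Numeric literals are wrapped as subscripts before parsing.
    print(d.to_json_path(exp.Literal.number(0)).sql())  # $[0]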
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
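The helpers from here on are building blocks that concrete dialects wire into their Parser and Generator subclasses. As a minimal sketch, a hypothetical generator could route exp.If through if_sql to emit an IIF call (MyGenerator and the IIF spelling are illustrative, not taken from the source):

    from sqlglot import exp
    from sqlglot.generator import Generator
    from sqlglot.dialects.dialect import if_sql

    class MyGenerator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            exp.If: if_sql(name="IIF", false_value="NULL"),
        }

    node = exp.If(this=exp.column("x"), true=exp.Literal.number(1))
    print(MyGenerator().generate(node))  # IIF(x, 1, NULL)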
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
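Parser-side builders like this one are typically registered in a parser's FUNCTIONS mapping. A sketch, with MyParser and the DATE_FORMAT spelling as illustrative assumptions:

    from sqlglot import exp
    from sqlglot.parser import Parser
    from sqlglot.dialects.dialect import build_formatted_time

    class MyParser(Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            # Parse DATE_FORMAT(x, fmt) into exp.TimeToStr, converting the
            # format string through the hive dialect's TIME_MAPPING.
            "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "hive"),
        }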
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
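A sketch of the builder in isolation, assuming the default generator renders exp.DateAdd through its generic function fallback (the unit mapping is illustrative):

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_date_delta

    # DATEADD(unit, value, expr) -> exp.DateAdd, normalizing Teradata-style
    # unit spellings through the optional mapping.
    builder = build_date_delta(exp.DateAdd, {"sql_tsi_day": "DAY"})

    args = [exp.var("SQL_TSI_DAY"), exp.Literal.number(1), exp.column("d")]
    print(builder(args).sql())  # e.g. DATE_ADD(d, 1, DAY)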
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql
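This is the generator-side counterpart to the date-delta builders above; a sketch of wiring it into TRANSFORMS (MyGenerator and the DATEADD/DATEDIFF spellings are illustrative):

    from sqlglot import exp
    from sqlglot.generator import Generator
    from sqlglot.dialects.dialect import date_delta_sql

    class MyGenerator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            # Render date arithmetic as DATEADD(unit, value, expr) / DATEDIFF(...).
            exp.DateAdd: date_delta_sql("DATEADD"),
            exp.DateDiff: date_delta_sql("DATEDIFF"),
        }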
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
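A sketch of the builder in isolation (the argument values are illustrative; the printed path assumes the default JSONPath rendering):

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    builder = build_json_extract_path(exp.JSONExtract)

    # JSON_EXTRACT(j, 'a', 0)-style args: string segments become keys,
    # integer segments become subscripts.
    node = builder([exp.column("j"), exp.Literal.string("a"), exp.Literal.number(0)])
    print(node.expression.sql())  # $.a[0]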
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))