Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import functools
   5import itertools
   6import typing as t
   7from collections import deque
   8from decimal import Decimal
   9from functools import reduce
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  15
if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # Signature of the DATE_TRUNC comparison rewriters stored in DATETRUNC_BINARY_COMPARISONS:
    # (left operand, date, unit, dialect, target type) -> rewritten expression, or None
    # when the rewriter declines to produce one.
    DateTruncBinaryTransform = t.Callable[
        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
    ]
  22
# Key stored in `Expression.meta`; an expression flagged with it is returned
# untouched by `simplify` (i.e. it should not be simplified).
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers, used to decide whether an
# integer cast around a small literal can be dropped (see `_simplify_integer_cast`).
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  31
  32
  33class UnsupportedUnit(Exception):
  34    pass
  35
  36
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rules are applied repeatedly (via `while_changing`) until `_simplify` stops
    changing the tree, after which `remove_where_true` is run once on the result.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect specification, resolved with `Dialect.get_or_raise` and passed
            to the DATE_TRUNC simplification rule

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL are left exactly as they are.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys.
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same treatment for the HAVING clause as a whole.
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are simplified as non-root nodes.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may have produced a new node; give it the original parent
        # before running the remaining rules, some of which inspect `parent`.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the root call splices the result into the tree; for child nodes the
        # returned value is installed by `exp.replace_children` above.
        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 119
 120
 121def catch(*exceptions):
 122    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 123
 124    def decorator(func):
 125        def wrapped(expression, *args, **kwargs):
 126            try:
 127                return func(expression, *args, **kwargs)
 128            except exceptions:
 129                return expression
 130
 131        return wrapped
 132
 133    return decorator
 134
 135
 136def rewrite_between(expression: exp.Expression) -> exp.Expression:
 137    """Rewrite x between y and z to x >= y AND x <= z.
 138
 139    This is done because comparison simplification is only done on lt/lte/gt/gte.
 140    """
 141    if isinstance(expression, exp.Between):
 142        negate = isinstance(expression.parent, exp.Not)
 143
 144        expression = exp.and_(
 145            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 146            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 147            copy=False,
 148        )
 149
 150        if negate:
 151            expression = exp.paren(expression, copy=False)
 152
 153    return expression
 154
 155
# Maps each comparison class to the comparison that is its logical negation.
# `simplify_not` uses it to rewrite e.g. NOT (a < b) as a >= b.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 164
 165
 166def simplify_not(expression):
 167    """
 168    Demorgan's Law
 169    NOT (x OR y) -> NOT x AND NOT y
 170    NOT (x AND y) -> NOT x OR NOT y
 171    """
 172    if isinstance(expression, exp.Not):
 173        this = expression.this
 174        if is_null(this):
 175            return exp.null()
 176        if this.__class__ in COMPLEMENT_COMPARISONS:
 177            return COMPLEMENT_COMPARISONS[this.__class__](
 178                this=this.this, expression=this.expression
 179            )
 180        if isinstance(this, exp.Paren):
 181            condition = this.unnest()
 182            if isinstance(condition, exp.And):
 183                return exp.paren(
 184                    exp.or_(
 185                        exp.not_(condition.left, copy=False),
 186                        exp.not_(condition.right, copy=False),
 187                        copy=False,
 188                    )
 189                )
 190            if isinstance(condition, exp.Or):
 191                return exp.paren(
 192                    exp.and_(
 193                        exp.not_(condition.left, copy=False),
 194                        exp.not_(condition.right, copy=False),
 195                        copy=False,
 196                    )
 197                )
 198            if is_null(condition):
 199                return exp.null()
 200        if always_true(this):
 201            return exp.false()
 202        if is_false(this):
 203            return exp.true()
 204        if isinstance(this, exp.Not):
 205            # double negation
 206            # NOT NOT x -> x
 207            return this.this
 208    return expression
 209
 210
 211def flatten(expression):
 212    """
 213    A AND (B AND C) -> A AND B AND C
 214    A OR (B OR C) -> A OR B OR C
 215    """
 216    if isinstance(expression, exp.Connector):
 217        for node in expression.args.values():
 218            child = node.unnest()
 219            if isinstance(child, expression.__class__):
 220                node.replace(child)
 221    return expression
 222
 223
 224def simplify_connectors(expression, root=True):
 225    def _simplify_connectors(expression, left, right):
 226        if left == right:
 227            return left
 228        if isinstance(expression, exp.And):
 229            if is_false(left) or is_false(right):
 230                return exp.false()
 231            if is_null(left) or is_null(right):
 232                return exp.null()
 233            if always_true(left) and always_true(right):
 234                return exp.true()
 235            if always_true(left):
 236                return right
 237            if always_true(right):
 238                return left
 239            return _simplify_comparison(expression, left, right)
 240        elif isinstance(expression, exp.Or):
 241            if always_true(left) or always_true(right):
 242                return exp.true()
 243            if is_false(left) and is_false(right):
 244                return exp.false()
 245            if (
 246                (is_null(left) and is_null(right))
 247                or (is_null(left) and is_false(right))
 248                or (is_false(left) and is_null(right))
 249            ):
 250                return exp.null()
 251            if is_false(left):
 252                return right
 253            if is_false(right):
 254                return left
 255            return _simplify_comparison(expression, left, right, or_=True)
 256
 257    if isinstance(expression, exp.Connector):
 258        return _flat_simplify(expression, _simplify_connectors, root)
 259    return expression
 260
 261
# Inequalities grouped by direction; `_simplify_comparison` uses the groups to decide
# which of two bounds on the same operand can be dropped.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Binary predicates that the comparison-based rules in this module operate on.
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Each inequality mapped to the one pointing the other way (LT <-> GT, LTE <-> GTE).
# Not referenced within this section of the file.
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Function classes treated as non-deterministic: `_simplify_comparison` does not merge
# comparisons on a shared operand that contains one of them.
NONDETERMINISTIC = (exp.Rand, exp.Randn)
 281
 282
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons that constrain the same operand.

    Args:
        expression: the connector joining `left` and `right`.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the connector is an OR, False for an AND.

    Returns:
        The surviving comparison or FALSE when the pair can be reduced, None when no
        rule applies, and `expression` itself when one side has no operand left after
        removing the shared one.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands appearing in both comparisons; only non-constant, deterministic ones
        # qualify as the shared "column" being constrained.
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The remaining operand on each side is the value compared against.
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Convert both values to comparable Python objects of the same kind.
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            # Try both orderings so that every rule only has to be written once.
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two upper bounds: AND keeps the tighter one, OR keeps the looser one.
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                # Two lower bounds: same idea, mirrored.
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Disjoint ranges under AND can never both hold.
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality combined with another constraint is either
                        # contradictory (FALSE) or makes the other constraint redundant.
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 343
 344
 345def remove_complements(expression, root=True):
 346    """
 347    Removing complements.
 348
 349    A AND NOT A -> FALSE
 350    A OR NOT A -> TRUE
 351    """
 352    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 353        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 354
 355        for a, b in itertools.permutations(expression.flatten(), 2):
 356            if is_complement(a, b):
 357                return complement
 358    return expression
 359
 360
 361def uniq_sort(expression, root=True):
 362    """
 363    Uniq and sort a connector.
 364
 365    C AND A AND B AND B -> A AND B AND C
 366    """
 367    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 368        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 369        flattened = tuple(expression.flatten())
 370        deduped = {gen(e): e for e in flattened}
 371        arr = tuple(deduped.items())
 372
 373        # check if the operands are already sorted, if not sort them
 374        # A AND C AND B -> A AND B AND C
 375        for i, (sql, e) in enumerate(arr[1:]):
 376            if sql < arr[i][0]:
 377                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 378                break
 379        else:
 380            # we didn't have to sort but maybe we need to dedup
 381            if len(deduped) < len(flattened):
 382                expression = result_func(*deduped.values(), copy=False)
 383
 384    return expression
 385
 386
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrite happens in place: operands are replaced inside `expression`, typically
    with TRUE/FALSE placeholders or a shared operand, and the (same) connector object
    is returned. Folding the placeholders away is left to the other rules.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector, i.e. the type of the nested operand to inspect.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # b contradicts aa, so aa becomes the neutral element of `kind`'s dual:
                    # TRUE inside an AND, FALSE inside an OR.
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a proper subset of a's: a is implied by b and is
                    # replaced by the neutral element of the outer connector.
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # One operand is shared and the other two are complements of each
                    # other: both sides reduce to the shared operand.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 425
 426
 427def propagate_constants(expression, root=True):
 428    """
 429    Propagate constants for conjunctions in DNF:
 430
 431    SELECT * FROM t WHERE a = b AND b = 5 becomes
 432    SELECT * FROM t WHERE a = 5 AND b = 5
 433
 434    Reference: https://www.sqlite.org/optoverview.html
 435    """
 436
 437    if (
 438        isinstance(expression, exp.And)
 439        and (root or not expression.same_parent)
 440        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 441    ):
 442        constant_mapping = {}
 443        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 444            if isinstance(expr, exp.EQ):
 445                l, r = expr.left, expr.right
 446
 447                # TODO: create a helper that can be used to detect nested literal expressions such
 448                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 449                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 450                    constant_mapping[l] = (id(l), r)
 451
 452        if constant_mapping:
 453            for column in find_all_in_scope(expression, exp.Column):
 454                parent = column.parent
 455                column_id, constant = constant_mapping.get(column) or (None, None)
 456                if (
 457                    column_id is not None
 458                    and id(column) != column_id
 459                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 460                ):
 461                    column.replace(constant.copy())
 462
 463    return expression
 464
 465
# Date add/sub functions mapped to the arithmetic operator that undoes them;
# `simplify_equality` uses this to move the interval to the other side of a comparison.
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All operations `simplify_equality` knows how to invert: the date functions above
# plus plain addition and subtraction.
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 478
 479
 480def _is_number(expression: exp.Expression) -> bool:
 481    return expression.is_number
 482
 483
 484def _is_interval(expression: exp.Expression) -> bool:
 485    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 486
 487
 488@catch(ModuleNotFoundError, UnsupportedUnit)
 489def simplify_equality(expression: exp.Expression) -> exp.Expression:
 490    """
 491    Use the subtraction and addition properties of equality to simplify expressions:
 492
 493        x + 1 = 3 becomes x = 2
 494
 495    There are two binary operations in the above expression: + and =
 496    Here's how we reference all the operands in the code below:
 497
 498          l     r
 499        x + 1 = 3
 500        a   b
 501    """
 502    if isinstance(expression, COMPARISONS):
 503        l, r = expression.left, expression.right
 504
 505        if l.__class__ not in INVERSE_OPS:
 506            return expression
 507
 508        if r.is_number:
 509            a_predicate = _is_number
 510            b_predicate = _is_number
 511        elif _is_date_literal(r):
 512            a_predicate = _is_date_literal
 513            b_predicate = _is_interval
 514        else:
 515            return expression
 516
 517        if l.__class__ in INVERSE_DATE_OPS:
 518            l = t.cast(exp.IntervalOp, l)
 519            a = l.this
 520            b = l.interval()
 521        else:
 522            l = t.cast(exp.Binary, l)
 523            a, b = l.left, l.right
 524
 525        if not a_predicate(a) and b_predicate(b):
 526            pass
 527        elif not a_predicate(b) and b_predicate(a):
 528            a, b = b, a
 529        else:
 530            return expression
 531
 532        return expression.__class__(
 533            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 534        )
 535    return expression
 536
 537
 538def simplify_literals(expression, root=True):
 539    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 540        return _flat_simplify(expression, _simplify_binary, root)
 541
 542    if isinstance(expression, exp.Neg):
 543        this = expression.this
 544        if this.is_number:
 545            value = this.name
 546            if value[0] == "-":
 547                return exp.Literal.number(value[1:])
 548            return exp.Literal.number(f"-{value}")
 549
 550    if type(expression) in INVERSE_DATE_OPS:
 551        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 552
 553    return expression
 554
 555
# Binary operators for which a NULL operand must not turn the whole expression into
# NULL; `_simplify_binary` skips them entirely.
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 557
 558
 559def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 560    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 561        this = _simplify_integer_cast(expr.this)
 562    else:
 563        this = expr.this
 564
 565    if isinstance(expr, exp.Cast) and this.is_int:
 566        num = int(this.name)
 567
 568        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 569        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 570        # engine-dependent
 571        if (
 572            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 573        ) or (
 574            UTINYINT_MIN <= num <= UTINYINT_MAX
 575            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 576        ):
 577            return this
 578
 579    return expr
 580
 581
 582def _simplify_binary(expression, a, b):
 583    if isinstance(expression, COMPARISONS):
 584        a = _simplify_integer_cast(a)
 585        b = _simplify_integer_cast(b)
 586
 587    if isinstance(expression, exp.Is):
 588        if isinstance(b, exp.Not):
 589            c = b.this
 590            not_ = True
 591        else:
 592            c = b
 593            not_ = False
 594
 595        if is_null(c):
 596            if isinstance(a, exp.Literal):
 597                return exp.true() if not_ else exp.false()
 598            if is_null(a):
 599                return exp.false() if not_ else exp.true()
 600    elif isinstance(expression, NULL_OK):
 601        return None
 602    elif is_null(a) or is_null(b):
 603        return exp.null()
 604
 605    if a.is_number and b.is_number:
 606        num_a = int(a.name) if a.is_int else Decimal(a.name)
 607        num_b = int(b.name) if b.is_int else Decimal(b.name)
 608
 609        if isinstance(expression, exp.Add):
 610            return exp.Literal.number(num_a + num_b)
 611        if isinstance(expression, exp.Mul):
 612            return exp.Literal.number(num_a * num_b)
 613
 614        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 615        if isinstance(expression, exp.Sub):
 616            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 617        if isinstance(expression, exp.Div):
 618            # engines have differing int div behavior so intdiv is not safe
 619            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 620                return None
 621            return exp.Literal.number(num_a / num_b)
 622
 623        boolean = eval_boolean(expression, num_a, num_b)
 624
 625        if boolean:
 626            return boolean
 627    elif a.is_string and b.is_string:
 628        boolean = eval_boolean(expression, a.this, b.this)
 629
 630        if boolean:
 631            return boolean
 632    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 633        date, b = extract_date(a), extract_interval(b)
 634        if date and b:
 635            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 636                return date_literal(date + b, extract_type(a))
 637            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 638                return date_literal(date - b, extract_type(a))
 639    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 640        a, date = extract_interval(a), extract_date(b)
 641        # you cannot subtract a date from an interval
 642        if a and b and isinstance(expression, exp.Add):
 643            return date_literal(a + date, extract_type(b))
 644    elif _is_date_literal(a) and _is_date_literal(b):
 645        if isinstance(expression, exp.Predicate):
 646            a, b = extract_date(a), extract_date(b)
 647            boolean = eval_boolean(expression, a, b)
 648            if boolean:
 649                return boolean
 650
 651    return None
 652
 653
 654def simplify_parens(expression):
 655    if not isinstance(expression, exp.Paren):
 656        return expression
 657
 658    this = expression.this
 659    parent = expression.parent
 660    parent_is_predicate = isinstance(parent, exp.Predicate)
 661
 662    if (
 663        not isinstance(this, exp.Select)
 664        and not isinstance(parent, exp.SubqueryPredicate)
 665        and (
 666            not isinstance(parent, (exp.Condition, exp.Binary))
 667            or isinstance(parent, exp.Paren)
 668            or (
 669                not isinstance(this, exp.Binary)
 670                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
 671            )
 672            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
 673            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 674            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 675            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 676        )
 677    ):
 678        return this
 679    return expression
 680
 681
 682def _is_nonnull_constant(expression: exp.Expression) -> bool:
 683    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 684
 685
 686def _is_constant(expression: exp.Expression) -> bool:
 687    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 688
 689
 690def simplify_coalesce(expression):
 691    # COALESCE(x) -> x
 692    if (
 693        isinstance(expression, exp.Coalesce)
 694        and (not expression.expressions or _is_nonnull_constant(expression.this))
 695        # COALESCE is also used as a Spark partitioning hint
 696        and not isinstance(expression.parent, exp.Hint)
 697    ):
 698        return expression.this
 699
 700    if not isinstance(expression, COMPARISONS):
 701        return expression
 702
 703    if isinstance(expression.left, exp.Coalesce):
 704        coalesce = expression.left
 705        other = expression.right
 706    elif isinstance(expression.right, exp.Coalesce):
 707        coalesce = expression.right
 708        other = expression.left
 709    else:
 710        return expression
 711
 712    # This transformation is valid for non-constants,
 713    # but it really only does anything if they are both constants.
 714    if not _is_constant(other):
 715        return expression
 716
 717    # Find the first constant arg
 718    for arg_index, arg in enumerate(coalesce.expressions):
 719        if _is_constant(arg):
 720            break
 721    else:
 722        return expression
 723
 724    coalesce.set("expressions", coalesce.expressions[:arg_index])
 725
 726    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 727    # since we already remove COALESCE at the top of this function.
 728    coalesce = coalesce if coalesce.expressions else coalesce.this
 729
 730    # This expression is more complex than when we started, but it will get simplified further
 731    return exp.paren(
 732        exp.or_(
 733            exp.and_(
 734                coalesce.is_(exp.null()).not_(copy=False),
 735                expression.copy(),
 736                copy=False,
 737            ),
 738            exp.and_(
 739                coalesce.is_(exp.null()),
 740                type(expression)(this=arg.copy(), expression=other.copy()),
 741                copy=False,
 742            ),
 743            copy=False,
 744        )
 745    )
 746
 747
# Expression classes handled by `simplify_concat`: CONCAT-style calls and the `||` operator.
CONCATS = (exp.Concat, exp.DPipe)
 749
 750
 751def simplify_concat(expression):
 752    """Reduces all groups that contain string literals by concatenating them."""
 753    if not isinstance(expression, CONCATS) or (
 754        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 755        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 756    ):
 757        return expression
 758
 759    if isinstance(expression, exp.ConcatWs):
 760        sep_expr, *expressions = expression.expressions
 761        sep = sep_expr.name
 762        concat_type = exp.ConcatWs
 763        args = {}
 764    else:
 765        expressions = expression.expressions
 766        sep = ""
 767        concat_type = exp.Concat
 768        args = {
 769            "safe": expression.args.get("safe"),
 770            "coalesce": expression.args.get("coalesce"),
 771        }
 772
 773    new_args = []
 774    for is_string_group, group in itertools.groupby(
 775        expressions or expression.flatten(), lambda e: e.is_string
 776    ):
 777        if is_string_group:
 778            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 779        else:
 780            new_args.extend(group)
 781
 782    if len(new_args) == 1 and new_args[0].is_string:
 783        return new_args[0]
 784
 785    if concat_type is exp.ConcatWs:
 786        new_args = [sep_expr] + new_args
 787    elif isinstance(expression, exp.DPipe):
 788        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
 789
 790    return concat_type(expressions=new_args, **args)
 791
 792
 793def simplify_conditionals(expression):
 794    """Simplifies expressions like IF, CASE if their condition is statically known."""
 795    if isinstance(expression, exp.Case):
 796        this = expression.this
 797        for case in expression.args["ifs"]:
 798            cond = case.this
 799            if this:
 800                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 801                cond = cond.replace(this.pop().eq(cond))
 802
 803            if always_true(cond):
 804                return case.args["true"]
 805
 806            if always_false(cond):
 807                case.pop()
 808                if not expression.args["ifs"]:
 809                    return expression.args.get("default") or exp.null()
 810    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 811        if always_true(expression.this):
 812            return expression.args["true"]
 813        if always_false(expression.this):
 814            return expression.args.get("false") or exp.null()
 815
 816    return expression
 817
 818
 819def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 820    """
 821    Reduces a prefix check to either TRUE or FALSE if both the string and the
 822    prefix are statically known.
 823
 824    Example:
 825        >>> from sqlglot import parse_one
 826        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 827        'TRUE'
 828    """
 829    if (
 830        isinstance(expression, exp.StartsWith)
 831        and expression.this.is_string
 832        and expression.expression.is_string
 833    ):
 834        return exp.convert(expression.name.startswith(expression.expression.name))
 835
 836    return expression
 837
 838
# A pair of dates; `_datetrunc_range` returns it as a half-open interval [min, max).
DateRange = t.Tuple[datetime.date, datetime.date]
 840
 841
 842def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 843    """
 844    Get the date range for a DATE_TRUNC equality comparison:
 845
 846    Example:
 847        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 848    Returns:
 849        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 850    """
 851    floor = date_floor(date, unit, dialect)
 852
 853    if date != floor:
 854        # This will always be False, except for NULL values.
 855        return None
 856
 857    return floor, floor + interval(unit)
 858
 859
 860def _datetrunc_eq_expression(
 861    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 862) -> exp.Expression:
 863    """Get the logical expression for a date range"""
 864    return exp.and_(
 865        left >= date_literal(drange[0], target_type),
 866        left < date_literal(drange[1], target_type),
 867        copy=False,
 868    )
 869
 870
 871def _datetrunc_eq(
 872    left: exp.Expression,
 873    date: datetime.date,
 874    unit: str,
 875    dialect: Dialect,
 876    target_type: t.Optional[exp.DataType],
 877) -> t.Optional[exp.Expression]:
 878    drange = _datetrunc_range(date, unit, dialect)
 879    if not drange:
 880        return None
 881
 882    return _datetrunc_eq_expression(left, drange, target_type)
 883
 884
 885def _datetrunc_neq(
 886    left: exp.Expression,
 887    date: datetime.date,
 888    unit: str,
 889    dialect: Dialect,
 890    target_type: t.Optional[exp.DataType],
 891) -> t.Optional[exp.Expression]:
 892    drange = _datetrunc_range(date, unit, dialect)
 893    if not drange:
 894        return None
 895
 896    return exp.and_(
 897        left < date_literal(drange[0], target_type),
 898        left >= date_literal(drange[1], target_type),
 899        copy=False,
 900    )
 901
 902
 903DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 904    exp.LT: lambda l, dt, u, d, t: l
 905    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
 906    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
 907    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
 908    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
 909    exp.EQ: _datetrunc_eq,
 910    exp.NEQ: _datetrunc_neq,
 911}
 912DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 913DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 914
 915
 916def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 917    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 918
 919
 920@catch(ModuleNotFoundError, UnsupportedUnit)
 921def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
 922    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 923    comparison = expression.__class__
 924
 925    if isinstance(expression, DATETRUNCS):
 926        this = expression.this
 927        trunc_type = extract_type(this)
 928        date = extract_date(this)
 929        if date and expression.unit:
 930            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
 931    elif comparison not in DATETRUNC_COMPARISONS:
 932        return expression
 933
 934    if isinstance(expression, exp.Binary):
 935        l, r = expression.left, expression.right
 936
 937        if not _is_datetrunc_predicate(l, r):
 938            return expression
 939
 940        l = t.cast(exp.DateTrunc, l)
 941        trunc_arg = l.this
 942        unit = l.unit.name.lower()
 943        date = extract_date(r)
 944
 945        if not date:
 946            return expression
 947
 948        return (
 949            DATETRUNC_BINARY_COMPARISONS[comparison](
 950                trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
 951            )
 952            or expression
 953        )
 954
 955    if isinstance(expression, exp.In):
 956        l = expression.this
 957        rs = expression.expressions
 958
 959        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 960            l = t.cast(exp.DateTrunc, l)
 961            unit = l.unit.name.lower()
 962
 963            ranges = []
 964            for r in rs:
 965                date = extract_date(r)
 966                if not date:
 967                    return expression
 968                drange = _datetrunc_range(date, unit, dialect)
 969                if drange:
 970                    ranges.append(drange)
 971
 972            if not ranges:
 973                return expression
 974
 975            ranges = merge_ranges(ranges)
 976            target_type = extract_type(l, *rs)
 977
 978            return exp.or_(
 979                *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
 980            )
 981
 982    return expression
 983
 984
 985def sort_comparison(expression: exp.Expression) -> exp.Expression:
 986    if expression.__class__ in COMPLEMENT_COMPARISONS:
 987        l, r = expression.this, expression.expression
 988        l_column = isinstance(l, exp.Column)
 989        r_column = isinstance(r, exp.Column)
 990        l_const = _is_constant(l)
 991        r_const = _is_constant(r)
 992
 993        if (l_column and not r_column) or (r_const and not l_const):
 994            return expression
 995        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
 996            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
 997                this=r, expression=l
 998            )
 999    return expression
1000
1001
# A CROSS join produces an empty result if the right table is empty,
# so only certain join types can be simplified to CROSS.
# In other words, LEFT JOIN x ON TRUE != CROSS JOIN x.
# Entries are (side, kind) pairs, matched against (join.side, join.kind) in `remove_where_true`.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
1011
1012
def remove_where_true(expression):
    """
    Strip conditions that are always true, mutating `expression` in place.

    WHERE clauses whose condition is always true are removed. A join whose ON
    condition is always true, that has no USING / method, and whose
    (side, kind) is listed in JOINS is turned into a CROSS join.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1027
1028
def always_true(expression):
    """
    Return a truthy value when `expression` is statically known to be true.

    That is the case for a TRUE boolean, for any string literal, and for any
    numeric literal other than zero. A numeric zero (e.g. `0`, `0.0`) is not
    considered true, so callers will not fold `WHERE 0` or `CASE WHEN 0 ...`
    as if the condition held.
    """
    if isinstance(expression, exp.Boolean):
        return expression.this
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_string:
        return True

    try:
        # Decimal compares exactly, so tiny non-zero values are not rounded to zero.
        return Decimal(expression.this) != 0
    except (ArithmeticError, TypeError, ValueError):
        # The literal's text could not be parsed as a number: keep treating it as true.
        return True
1033
1034
def always_false(expression):
    """Return True when `expression` is the FALSE boolean or NULL."""
    if is_false(expression):
        return True
    return is_null(expression)
1037
1038
def is_complement(a, b):
    """Return whether `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1041
1042
def is_false(a: exp.Expression) -> bool:
    """Return True only for an exact `exp.Boolean` node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1045
1046
def is_null(a: exp.Expression) -> bool:
    """Return True only when `a` is exactly an `exp.Null` node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
1049
1050
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on the Python values
    `a` and `b`.

    Returns:
        A boolean literal for EQ / IS / NEQ / GT / GTE / LT / LTE nodes,
        or None for any other expression type.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None

    return boolean_literal(outcome)
1065
1066
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    A datetime is reduced to its date part, a date is returned as-is, and any
    other value is parsed with `datetime.fromisoformat`. Returns None when
    parsing raises ValueError.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed.date()
1076
1077
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    A datetime is returned as-is, a date becomes midnight of that day, and any
    other value is parsed with `datetime.fromisoformat`. Returns None when
    parsing raises ValueError.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None

    return parsed
1087
1088
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Convert `value` according to the target data type `to`.

    DATE targets go through `cast_as_date`, any other temporal target through
    `cast_as_datetime`. Falsy values and non-temporal targets yield None.
    """
    if value:
        if to.is_type(exp.DataType.Type.DATE):
            return cast_as_date(value)
        if to.is_type(*exp.DataType.TEMPORAL_TYPES):
            return cast_as_datetime(value)

    return None
1097
1098
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Extract the Python date / datetime represented by a cast of a literal.

    Handles `CAST(<literal> AS <type>)` and `TsOrDsToDate` without a format,
    including such casts nested inside each other. Returns None for anything
    else, or when the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        target = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        target = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    inner = cast.this
    if isinstance(inner, exp.Literal):
        raw: t.Any = inner.name
    elif isinstance(inner, (exp.Cast, exp.TsOrDsToDate)):
        # nested cast: resolve the inner one first, then convert its result
        raw = extract_date(inner)
    else:
        return None

    return cast_value(raw, target)
1114
1115
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True when `extract_date` can pull a date out of `expression`."""
    extracted = extract_date(expression)
    return extracted is not None
1118
1119
def extract_interval(expression):
    """
    Turn an interval expression into the object returned by `interval`.

    Returns None when the amount is not an integer, the unit is unsupported,
    or a module needed by `interval` is missing.
    """
    try:
        amount = int(expression.name)
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1127
1128
def extract_type(*expressions):
    """
    Return the first truthy type found among `expressions`.

    For a Cast the candidate is its target type (`to`); for every other
    expression it is the `type` attribute. When no candidate is truthy, the
    last candidate examined is returned (None if no expressions are given).
    """
    candidate = None

    for node in expressions:
        candidate = node.to if isinstance(node, exp.Cast) else node.type
        if candidate:
            return candidate

    return candidate
1137
1138
def date_literal(date, target_type=None):
    """
    Build `CAST('<date>' AS <type>)` for a Python date or datetime.

    `target_type` is used when it is a temporal type; otherwise the type
    falls back to DATETIME for datetime values and DATE for everything else.
    """
    usable_type = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)

    if not usable_type:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
1148
1149
def interval(unit: str, n: int = 1):
    """
    Build a `relativedelta` spanning `n` times the given unit.

    Supported units: year, quarter (three months), month, week, day, hour,
    minute and second.

    Raises:
        UnsupportedUnit: if `unit` is none of the above.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier applied to n)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = span
    return relativedelta(**{keyword: factor * n})
1171
1172
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date (or datetime) to floor. A datetime stays a datetime, with
            its time-of-day reset to midnight.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides `WEEK_OFFSET`, the shift of the first day of the
            week relative to Monday (0 means weeks start on Monday).

    Returns:
        The first day of the enclosing year / quarter / month / week, or the
        day itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is none of the supported units.
    """
    if isinstance(d, datetime.datetime):
        # Every supported unit is at least a day wide, so the floor of a
        # datetime never keeps a time-of-day component.
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # quarters start in months 1, 4, 7 and 10
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday. Taking the distance to the
        # week start modulo 7 keeps the result within the current week: without
        # it, a date that is itself the week start could jump a full week back,
        # and a positive offset could yield a date in the future.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1194
1195
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the next `unit` boundary.

    A value that already equals its own floor is returned unchanged;
    otherwise the result is the floor plus one unit.
    """
    lower = date_floor(d, unit, dialect)
    return d if lower == d else lower + interval(unit)
1203
1204
def boolean_literal(condition):
    """Return `exp.true()` when `condition` is truthy, else `exp.false()`."""
    if condition:
        return exp.true()
    return exp.false()
1207
1208
def _flat_simplify(expression, simplifier, root=True):
    """
    Pairwise-simplify the flattened operands of `expression`.

    `simplifier(expression, a, b)` is tried on each operand `a` against the
    operands still pending. A truthy result other than `expression` itself
    replaces both operands and is re-examined. If fewer operands remain than
    were there initially, they are chained back together with the original
    expression class; otherwise `expression` is returned unchanged.

    Nothing is done for a non-root node whose `same_parent` is truthy.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    initial_count = len(pending)
    kept = []

    while pending:
        head = pending.popleft()

        for other in pending:
            merged = simplifier(expression, head, other)

            if merged and merged is not expression:
                # the pair collapsed: drop the partner and revisit the result first
                pending.remove(other)
                pending.appendleft(merged)
                break
        else:
            # nothing combined with `head`, so it survives as-is
            kept.append(head)

    if len(kept) >= initial_count:
        return expression

    connector = expression.__class__
    return reduce(lambda lhs, rhs: connector(this=lhs, expression=rhs), kept)
1233
1234
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    generator = Gen()
    return generator.gen(expression)
1242
1243
class Gen:
    """
    Minimal, stack-based pseudo SQL generator.

    `gen` repeatedly pops items off `self.stack`: expressions are expanded by
    a `<key>_sql` handler (or a generic fallback), lists are expanded into
    their comma-separated items, and everything else is appended to
    `self.sqls` as text. Because the stack is LIFO, handlers push their pieces
    in reverse order of how they should appear in the output.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by any set args.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator that would precede the first item, but only
                # if this list pushed one: a list holding nothing but None must
                # not pop an unrelated entry (e.g. a closing parenthesis).
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # quoted identifiers keep their case, unquoted ones are upper-cased
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # remove the dot that would precede the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # remove the dot that would precede the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # pushed right-to-left so the left operand is emitted first
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the non-None args of `node` as `:name value` pairs.

        `arg_index` skips that many leading entries of `node.arg_types`.
        Returns True if anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
FINAL = 'final'
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date/time unit is not supported by the date helpers."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None):
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and passed to the
            DATE_TRUNC simplification

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL are returned untouched.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys.
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause.
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rewrites above may have produced a new node: attach it to the
        # original parent before running the remaining rules.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # NOTE(review): presumably `while_changing` re-applies `_simplify` until the
    # expression stops changing; confirm against sqlglot.helper.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """
    Decorator that ignores a simplification function if any of `exceptions` are raised.

    When the decorated function raises one of `exceptions`, its first argument
    (the expression) is returned unchanged. Any other exception propagates.
    """

    def decorator(func):
        # Preserve the wrapped function's name and docstring for introspection.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    The result is wrapped in parentheses when the BETWEEN sits directly under a NOT.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)

    lower_check = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_check = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_check, upper_check, copy=False)

    if under_not:
        return exp.paren(rewritten, copy=False)
    return rewritten

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression.

    Besides folding constants, complementing comparisons and removing double
    negation, this applies De Morgan's laws:
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_kind = operand.__class__
    if operand_kind in COMPLEMENT_COMPARISONS:
        complement_kind = COMPLEMENT_COMPARISONS[operand_kind]
        return complement_kind(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, exp.And):
            negated = exp.or_(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
            return exp.paren(negated)

        if isinstance(inner, exp.Or):
            negated = exp.and_(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
            return exp.paren(negated)

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Remove parentheses around operands of the same connector type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector):
            operand.replace(unwrapped)

    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Simplify AND / OR expressions by folding constant, NULL and duplicate
    operands, deferring to `_simplify_comparison` for everything else.
    """

    def _combine(connector, lhs, rhs):
        # identical operands collapse into one
        if lhs == rhs:
            return lhs

        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()

            lhs_true = always_true(lhs)
            if lhs_true and always_true(rhs):
                return exp.true()
            if lhs_true:
                return rhs
            if always_true(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()
            if is_false(lhs) and is_false(rhs):
                return exp.false()
            if (
                (is_null(lhs) and is_null(rhs))
                or (is_null(lhs) and is_false(rhs))
                or (is_false(lhs) and is_null(rhs))
            ):
                return exp.null()
            if is_false(lhs):
                return rhs
            if is_false(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _combine, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    for first, second in itertools.permutations(expression.flatten(), 2):
        if is_complement(first, second):
            return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # keyed by pseudo sql, so operands with the same text collapse into one entry
    by_sql = {gen(operand): operand for operand in operands}
    pairs = tuple(by_sql.items())

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(pairs, pairs[1:]))
    if out_of_order:
        return combine(*(operand for _, operand in sorted(pairs)), copy=False)

    # already sorted, but duplicates may still have to be dropped
    if len(by_sql) < len(operands):
        return combine(*by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, rewriting operands in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Args:
        expression: the node to inspect; only connectors are processed.
        root: when falsy, a connector whose `same_parent` is truthy is skipped.

    Returns:
        The same `expression` object; matching operands are replaced inside it.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the connector type opposite to the one being simplified
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Ordered pairs: `a` is the nested connector candidate, `b` the operand compared against it
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for the constant that leaves only `ab` significant in `a`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's: replace `a` with a constant
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # one operand is shared and the other two are complements: keep the shared one
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

absorption:
    A AND (A OR B) -> A
    A OR (A AND B) -> A
    A AND (NOT A OR B) -> A AND B
    A OR (NOT A AND B) -> A OR B
elimination:
    (A AND B) OR (A AND NOT B) -> A
    (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    if not isinstance(expression, exp.And):
        return expression
    if not root and expression.same_parent:
        return expression
    if not sqlglot.optimizer.normalize.normalized(expression, dnf=True):
        return expression

    # Maps a column to (id of the column node in the defining equality, literal).
    constants = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue

        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            constants[lhs] = (id(lhs), rhs)

    if not constants:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        source_id, literal = constants.get(column) or (None, None)

        # Skip unknown columns and the column node that defines the constant itself.
        if source_id is None or id(column) == source_id:
            continue
        # Leave `col IS NULL` checks untouched.
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes
SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
        def wrapped(expression, *args, **kwargs):
            # Wrapper closure: delegates to `func` and, if one of `exceptions` is raised,
            # returns the input expression unchanged instead of propagating the error.
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold literal operands of binary operations, negations and inverse date operations."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        # Negating a literal that already starts with "-" strips the sign instead of doubling it.
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """Drop a Paren wrapper when the checks below consider it redundant."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # Parenthesized SELECTs and operands of subquery predicates always keep their parentheses.
    if isinstance(this, exp.Select) or isinstance(parent, exp.SubqueryPredicate):
        return expression

    redundant = (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    )

    return this if redundant else expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """Unwrap trivial COALESCE calls and expand comparisons of a COALESCE against a constant."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Locate the first constant fallback argument, if there is one.
    first_constant = next(
        (
            (index, candidate)
            for index, candidate in enumerate(coalesce.expressions)
            if _is_constant(candidate)
        ),
        None,
    )
    if first_constant is None:
        return expression
    arg_index, arg = first_constant

    # Keep only the fallbacks that precede the constant.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    when_not_null = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_null = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(when_not_null, when_null, copy=False))
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    # Adjacent string literals are merged into one literal; everything else is kept as is.
    new_args = []
    operands = expressions or expression.flatten()
    for is_string_group, group in itertools.groupby(operands, key=lambda e: e.is_string):
        if not is_string_group:
            new_args.extend(group)
            continue
        merged = sep.join(string.name for string in group)
        new_args.append(exp.Literal.string(merged))

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        # Rebuild a left-nested chain of || operators.
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches whose condition is always false are removed. A branch whose
    condition is always true replaces the whole CASE only when no earlier branch with a
    statically unknown condition precedes it, since such a branch may still match at
    runtime and must win:

        CASE WHEN FALSE THEN x WHEN TRUE THEN y END -> y
        CASE WHEN c THEN x WHEN TRUE THEN y END     -> unchanged

    Returns:
        The simplified expression, or the (possibly pruned) input expression.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Set once a branch is kept whose outcome can't be decided statically.
        has_undecided_branch = False

        # Iterate over a snapshot: popping a branch mutates the live `ifs` list, which
        # would otherwise make the loop skip the branch following the removed one.
        for case in tuple(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if has_undecided_branch:
                    # Earlier branches may still match, so the CASE can't be collapsed.
                    break
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                has_undecided_branch = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a STARTSWITH call into TRUE or FALSE when both of its arguments
    are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    prefix = expression.expression
    if not (expression.this.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.In'>}
def simplify_datetrunc(expression, *args, **kwargs):
        def wrapped(expression, *args, **kwargs):
            # Wrapper closure: delegates to `func` and, if one of `exceptions` is raised,
            # returns the input expression unchanged instead of propagating the error.
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 986def sort_comparison(expression: exp.Expression) -> exp.Expression:
 987    if expression.__class__ in COMPLEMENT_COMPARISONS:
 988        l, r = expression.this, expression.expression
 989        l_column = isinstance(l, exp.Column)
 990        r_column = isinstance(r, exp.Column)
 991        l_const = _is_constant(l)
 992        r_const = _is_constant(r)
 993
 994        if (l_column and not r_column) or (r_const and not l_const):
 995            return expression
 996        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
 997            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
 998                this=r, expression=l
 999            )
1000    return expression
JOINS = {('', ''), ('', 'INNER'), ('RIGHT', ''), ('RIGHT', 'OUTER')}
def remove_where_true(expression):
def remove_where_true(expression):
    """Drop WHERE clauses that are always true and turn always-true joins into CROSS joins."""
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """
    Return a truthy value if `expression` is statically known to be true.

    A TRUE boolean and any literal qualify, except numeric literals equal to zero
    (`0`, `0.0`, ...): treating those as true would, for example, drop `WHERE 0`
    or rewrite `IF(0, a, b)` to `a`. String literals keep being treated as true.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # decimal.InvalidOperation: text Decimal can't parse; keep the lenient default.
            return True

    return True
def always_false(expression):
def always_false(expression):
    """Return True if `expression` is a FALSE boolean or a NULL node."""
    return is_false(expression) or is_null(expression)
def is_complement(a, b):
def is_complement(a, b):
    """Return True if `b` is `NOT a`. The check is directional: `a` is never unwrapped."""
    return isinstance(b, exp.Not) and b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True only for a node that is exactly `exp.Boolean` with a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Null` node (subclasses do not count)."""
    return type(a) is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on the Python values `a` and `b`.

    Returns a boolean literal node, or None if the expression type is not a supported comparison.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))

    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Args:
        value: a datetime (its date part is used), a date (returned as is), or an
            ISO-8601 formatted string.

    Returns:
        The date, or None if `value` can't be interpreted as one. This includes
        values that are neither date objects nor strings.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: fromisoformat only accepts str; ValueError: malformed string.
        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Args:
        value: a datetime (returned as is), a date (promoted to midnight of that day),
            or an ISO-8601 formatted string.

    Returns:
        The datetime, or None if `value` can't be interpreted as one. This includes
        values that are neither date objects nor strings.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: fromisoformat only accepts str; ValueError: malformed string.
        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Convert `value` according to the target data type `to`.

    DATE targets go through cast_as_date and other temporal targets through
    cast_as_datetime. Falsy values and non-temporal targets yield None.
    """
    if value:
        if to.is_type(exp.DataType.Type.DATE):
            return cast_as_date(value)
        if to.is_type(*exp.DataType.TEMPORAL_TYPES):
            return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Statically evaluate a (possibly nested) cast of a literal to a date or datetime.

    Returns None when `cast` is not a CAST or format-less TS_OR_DS_TO_DATE node, or when
    its operand is neither a literal nor another such cast.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        value: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are resolved innermost-first.
        value = extract_date(operand)
    else:
        return None

    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """
    Build a time delta from an interval expression via `interval`.

    Returns None when the amount is not an integer, the unit is unsupported, or a module
    needed to build the delta cannot be imported.
    """
    try:
        amount = int(expression.name)
        unit_name = expression.text("unit").lower()
        delta = interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
    return delta
def extract_type(*expressions):
def extract_type(*expressions):
    """
    Return the first truthy type found among `expressions`.

    For a Cast the target type (`to`) is used, otherwise the node's `type`. If no truthy
    type is found, the value read from the last expression is returned (None when there
    are no expressions).
    """
    found = None
    for node in expressions:
        found = node.to if isinstance(node, exp.Cast) else node.type
        if found:
            return found
    return found
def date_literal(date, target_type=None):
def date_literal(date, target_type=None):
    """
    Build a CAST of the string form of `date` to a temporal type.

    `target_type` is used when it is a temporal type; otherwise DATETIME is chosen for
    datetime values and DATE for everything else.
    """
    has_temporal_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)
    if not has_temporal_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """
    Return a `relativedelta` spanning `n` units.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week, day, hour,
            minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, multiplier applied to n)
    supported = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in supported:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only consulted for "week"; its WEEK_OFFSET shifts the first day of
            the week relative to Monday (0 means Monday).

    Returns:
        The first day of the period containing `d`, which is never after `d`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        return d.replace(month=(d.month - 1) // 3 * 3 + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday. The modulo keeps the number of days
        # to step back within 0..6, so a shifted week start never skips back a whole
        # week or lands after `d`.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; a date already on a boundary is returned as is."""
    start = date_floor(d, unit, dialect)
    return d if start == d else start + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return a TRUE node when `condition` is truthy, otherwise a FALSE node."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.

    Args:
        expression: the node to render; it is handed to a fresh `Gen` instance.

    Returns:
        The pseudo-SQL string produced by `Gen.gen`.
    """
    return Gen().gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
class Gen:
    """
    Minimal, non-recursive pseudo-SQL generator.

    Nodes are rendered with an explicit stack: a handler pushes the pieces of a node in
    reverse order so that popping them yields left-to-right output. Strings popped from
    the stack are appended to `sqls`, which is joined into the final result.
    """

    def __init__(self):
        # Work stack of pending pieces: expressions, lists and plain strings.
        self.stack = []
        # Output fragments in emission order.
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` into its pseudo-SQL string."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the node key followed by its set arguments.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Queue the non-None items so they pop in their original order,
                # separated by commas.
                pending = []
                for n in reversed(node):
                    if n is not None:
                        pending.extend((n, ","))
                if pending:
                    # Drop the separator that would otherwise precede the first item.
                    # Guarding on `pending` (not `node`) avoids popping an unrelated
                    # stack entry when every item of the list is None.
                    pending.pop()
                    self.stack.extend(pending)
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Render an anonymous function call; raises ValueError for an unexpected name type."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the dot that would precede the first part.
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Arguments after the first one are queued before the type name is pushed,
        # so the name is emitted first.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the dot that would precede the first part.
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Queue `this op expression` (pushed in reverse so `this` pops first)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Queue `op this` (pushed in reverse so `op` pops first)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Queue `NAME(arg, ...)` using every argument value of the function node."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Queue the set arguments of `node` as `:name,value` pairs.

        Args:
            node: the node whose arguments are rendered.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one argument was queued, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: sqlglot.expressions.Expression) -> str:
1250    def gen(self, expression: exp.Expression) -> str:
1251        self.stack = [expression]
1252        self.sqls.clear()
1253
1254        while self.stack:
1255            node = self.stack.pop()
1256
1257            if isinstance(node, exp.Expression):
1258                exp_handler_name = f"{node.key}_sql"
1259
1260                if hasattr(self, exp_handler_name):
1261                    getattr(self, exp_handler_name)(node)
1262                elif isinstance(node, exp.Func):
1263                    self._function(node)
1264                else:
1265                    key = node.key.upper()
1266                    self.stack.append(f"{key} " if self._args(node) else key)
1267            elif type(node) is list:
1268                for n in reversed(node):
1269                    if n is not None:
1270                        self.stack.extend((n, ","))
1271                if node:
1272                    self.stack.pop()
1273            else:
1274                if node is not None:
1275                    self.sqls.append(str(node))
1276
1277        return "".join(self.sqls)
def add_sql(self, e: sqlglot.expressions.Add) -> None:
    def add_sql(self, e: exp.Add) -> None:
        """Queue the operands of `e` joined by " + "."""
        self._binary(e, " + ")
def alias_sql(self, e: sqlglot.expressions.Alias) -> None:
1282    def alias_sql(self, e: exp.Alias) -> None:
1283        self.stack.extend(
1284            (
1285                e.args.get("alias"),
1286                " AS ",
1287                e.args.get("this"),
1288            )
1289        )
def and_sql(self, e: sqlglot.expressions.And) -> None:
    def and_sql(self, e: exp.And) -> None:
        """Queue the operands of `e` joined by " AND "."""
        self._binary(e, " AND ")
def anonymous_sql(self, e: sqlglot.expressions.Anonymous) -> None:
1294    def anonymous_sql(self, e: exp.Anonymous) -> None:
1295        this = e.this
1296        if isinstance(this, str):
1297            name = this.upper()
1298        elif isinstance(this, exp.Identifier):
1299            name = this.this
1300            name = f'"{name}"' if this.quoted else name.upper()
1301        else:
1302            raise ValueError(
1303                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
1304            )
1305
1306        self.stack.extend(
1307            (
1308                ")",
1309                e.expressions,
1310                "(",
1311                name,
1312            )
1313        )
def between_sql(self, e: sqlglot.expressions.Between) -> None:
1315    def between_sql(self, e: exp.Between) -> None:
1316        self.stack.extend(
1317            (
1318                e.args.get("high"),
1319                " AND ",
1320                e.args.get("low"),
1321                " BETWEEN ",
1322                e.this,
1323            )
1324        )
def boolean_sql(self, e: sqlglot.expressions.Boolean) -> None:
1326    def boolean_sql(self, e: exp.Boolean) -> None:
1327        self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: sqlglot.expressions.Bracket) -> None:
1329    def bracket_sql(self, e: exp.Bracket) -> None:
1330        self.stack.extend(
1331            (
1332                "]",
1333                e.expressions,
1334                "[",
1335                e.this,
1336            )
1337        )
def column_sql(self, e: sqlglot.expressions.Column) -> None:
1339    def column_sql(self, e: exp.Column) -> None:
1340        for p in reversed(e.parts):
1341            self.stack.extend((p, "."))
1342        self.stack.pop()
def datatype_sql(self, e: sqlglot.expressions.DataType) -> None:
1344    def datatype_sql(self, e: exp.DataType) -> None:
1345        self._args(e, 1)
1346        self.stack.append(f"{e.this.name} ")
def div_sql(self, e: sqlglot.expressions.Div) -> None:
    def div_sql(self, e: exp.Div) -> None:
        """Handle a ``Div`` node by delegating to ``_binary`` with the " / " operator text."""
        self._binary(e, " / ")
def dot_sql(self, e: sqlglot.expressions.Dot) -> None:
    def dot_sql(self, e: exp.Dot) -> None:
        """Handle a ``Dot`` node by delegating to ``_binary`` with "." (no surrounding spaces)."""
        self._binary(e, ".")
def eq_sql(self, e: sqlglot.expressions.EQ) -> None:
    def eq_sql(self, e: exp.EQ) -> None:
        """Handle an ``EQ`` node by delegating to ``_binary`` with the " = " operator text."""
        self._binary(e, " = ")
def from_sql(self, e: sqlglot.expressions.From) -> None:
1357    def from_sql(self, e: exp.From) -> None:
1358        self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: sqlglot.expressions.GT) -> None:
    def gt_sql(self, e: exp.GT) -> None:
        """Handle a ``GT`` node by delegating to ``_binary`` with the " > " operator text."""
        self._binary(e, " > ")
def gte_sql(self, e: sqlglot.expressions.GTE) -> None:
    def gte_sql(self, e: exp.GTE) -> None:
        """Handle a ``GTE`` node by delegating to ``_binary`` with the " >= " operator text."""
        self._binary(e, " >= ")
def identifier_sql(self, e: sqlglot.expressions.Identifier) -> None:
1366    def identifier_sql(self, e: exp.Identifier) -> None:
1367        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: sqlglot.expressions.ILike) -> None:
    def ilike_sql(self, e: exp.ILike) -> None:
        """Handle an ``ILike`` node by delegating to ``_binary`` with the " ILIKE " operator text."""
        self._binary(e, " ILIKE ")
def in_sql(self, e: sqlglot.expressions.In) -> None:
1372    def in_sql(self, e: exp.In) -> None:
1373        self.stack.append(")")
1374        self._args(e, 1)
1375        self.stack.extend(
1376            (
1377                "(",
1378                " IN ",
1379                e.this,
1380            )
1381        )
def intdiv_sql(self, e: sqlglot.expressions.IntDiv) -> None:
    def intdiv_sql(self, e: exp.IntDiv) -> None:
        """Handle an ``IntDiv`` node by delegating to ``_binary`` with the " DIV " operator text."""
        self._binary(e, " DIV ")
def is_sql(self, e: sqlglot.expressions.Is) -> None:
    def is_sql(self, e: exp.Is) -> None:
        """Handle an ``Is`` node by delegating to ``_binary`` with the " IS " operator text."""
        self._binary(e, " IS ")
def like_sql(self, e: sqlglot.expressions.Like) -> None:
1389    def like_sql(self, e: exp.Like) -> None:
1390        self._binary(e, " Like ")
def literal_sql(self, e: sqlglot.expressions.Literal) -> None:
1392    def literal_sql(self, e: exp.Literal) -> None:
1393        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: sqlglot.expressions.LT) -> None:
    def lt_sql(self, e: exp.LT) -> None:
        """Handle an ``LT`` node by delegating to ``_binary`` with the " < " operator text."""
        self._binary(e, " < ")
def lte_sql(self, e: sqlglot.expressions.LTE) -> None:
    def lte_sql(self, e: exp.LTE) -> None:
        """Handle an ``LTE`` node by delegating to ``_binary`` with the " <= " operator text."""
        self._binary(e, " <= ")
def mod_sql(self, e: sqlglot.expressions.Mod) -> None:
    def mod_sql(self, e: exp.Mod) -> None:
        """Handle a ``Mod`` node by delegating to ``_binary`` with the " % " operator text."""
        self._binary(e, " % ")
def mul_sql(self, e: sqlglot.expressions.Mul) -> None:
    def mul_sql(self, e: exp.Mul) -> None:
        """Handle a ``Mul`` node by delegating to ``_binary`` with the " * " operator text."""
        self._binary(e, " * ")
def neg_sql(self, e: sqlglot.expressions.Neg) -> None:
    def neg_sql(self, e: exp.Neg) -> None:
        """Handle a ``Neg`` node by delegating to ``_unary`` with the "-" prefix text."""
        self._unary(e, "-")
def neq_sql(self, e: sqlglot.expressions.NEQ) -> None:
    def neq_sql(self, e: exp.NEQ) -> None:
        """Handle an ``NEQ`` node by delegating to ``_binary`` with the " <> " operator text."""
        self._binary(e, " <> ")
def not_sql(self, e: sqlglot.expressions.Not) -> None:
    def not_sql(self, e: exp.Not) -> None:
        """Handle a ``Not`` node by delegating to ``_unary`` with the "NOT " prefix text."""
        self._unary(e, "NOT ")
def null_sql(self, e: sqlglot.expressions.Null) -> None:
    def null_sql(self, e: exp.Null) -> None:
        """Push the constant text "NULL"; the node itself is not inspected."""
        self.stack.append("NULL")
def or_sql(self, e: sqlglot.expressions.Or) -> None:
    def or_sql(self, e: exp.Or) -> None:
        """Handle an ``Or`` node by delegating to ``_binary`` with the " OR " operator text."""
        self._binary(e, " OR ")
def paren_sql(self, e: sqlglot.expressions.Paren) -> None:
1422    def paren_sql(self, e: exp.Paren) -> None:
1423        self.stack.extend(
1424            (
1425                ")",
1426                e.this,
1427                "(",
1428            )
1429        )
def sub_sql(self, e: sqlglot.expressions.Sub) -> None:
    def sub_sql(self, e: exp.Sub) -> None:
        """Handle a ``Sub`` node by delegating to ``_binary`` with the " - " operator text."""
        self._binary(e, " - ")
def subquery_sql(self, e: sqlglot.expressions.Subquery) -> None:
1434    def subquery_sql(self, e: exp.Subquery) -> None:
1435        self._args(e, 2)
1436        alias = e.args.get("alias")
1437        if alias:
1438            self.stack.append(alias)
1439        self.stack.extend((")", e.this, "("))
def table_sql(self, e: sqlglot.expressions.Table) -> None:
1441    def table_sql(self, e: exp.Table) -> None:
1442        self._args(e, 4)
1443        alias = e.args.get("alias")
1444        if alias:
1445            self.stack.append(alias)
1446        for p in reversed(e.parts):
1447            self.stack.extend((p, "."))
1448        self.stack.pop()
def tablealias_sql(self, e: sqlglot.expressions.TableAlias) -> None:
1450    def tablealias_sql(self, e: exp.TableAlias) -> None:
1451        columns = e.columns
1452
1453        if columns:
1454            self.stack.extend((")", columns, "("))
1455
1456        self.stack.extend((e.this, " AS "))
def var_sql(self, e: sqlglot.expressions.Var) -> None:
    def var_sql(self, e: exp.Var) -> None:
        """Push the variable's ``this`` value onto the stack unchanged."""
        self.stack.append(e.this)