Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,21 @@ def _regr_val_sql(
)


def _maybe_corr_null_to_false(
expression: t.Union[exp.Filter, exp.Window, exp.Corr],
) -> t.Optional[t.Union[exp.Filter, exp.Window, exp.Corr]]:
expr = expression.copy()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this copy necessary ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A window/filter AST containing CORR could be modified here then passed through to other rendering methods.

I was under the impression that, when modifying an AST, we should copy if it will then be passed to unknown downstream consumers.

It sounds like that's not correct?

Copy link
Collaborator

@geooo109 geooo109 Dec 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeap, I see the point, ^ this works.

I was thinking that .set doesn't break anything here, so we could skip the copy in some cases.

What do you think @VaggelisD ?

corr = expr
while isinstance(corr, (exp.Window, exp.Filter)):
corr = corr.this

if not isinstance(corr, exp.Corr) or not corr.args.get("null_on_zero_variance"):
return None

corr.set("null_on_zero_variance", False)
return expr


class DuckDB(Dialect):
NULL_ORDERING = "nulls_are_last"
SUPPORTS_USER_DEFINED_TYPES = True
Expand Down Expand Up @@ -1292,6 +1307,7 @@ class Generator(generator.Generator):
exp.BitwiseOrAgg: _bitwise_agg_sql,
exp.BitwiseXorAgg: _bitwise_agg_sql,
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
exp.Corr: lambda self, e: self._corr_sql(e),
exp.CosineDistance: rename_func("LIST_COSINE_DISTANCE"),
exp.CurrentTime: lambda *_: "CURRENT_TIME",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
Expand Down Expand Up @@ -2347,3 +2363,37 @@ def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str:
result_sql = f"~{self.sql(expression, 'this')}"

return _gen_with_cast_to_blob(self, expression, result_sql)

def window_sql(self, expression: exp.Window) -> str:
this = expression.this
if isinstance(this, exp.Corr) or (
isinstance(this, exp.Filter) and isinstance(this.this, exp.Corr)
):
return self._corr_sql(expression)

return super().window_sql(expression)

def filter_sql(self, expression: exp.Filter) -> str:
if isinstance(expression.this, exp.Corr):
return self._corr_sql(expression)

return super().filter_sql(expression)

def _corr_sql(
self,
expression: t.Union[exp.Filter, exp.Window, exp.Corr],
) -> str:
if isinstance(expression, exp.Corr) and not expression.args.get(
"null_on_zero_variance"
):
return self.func("CORR", expression.this, expression.expression)

corr_expr = _maybe_corr_null_to_false(expression)
if corr_expr is None:
if isinstance(expression, exp.Window):
return super().window_sql(expression)
if isinstance(expression, exp.Filter):
return super().filter_sql(expression)
corr_expr = expression # make mypy happy

return self.sql(exp.case().when(exp.IsNan(this=corr_expr), exp.null()).else_(corr_expr))
5 changes: 5 additions & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,11 @@ class Parser(parser.Parser):
"BIT_XORAGG": exp.BitwiseXorAgg.from_arg_list,
"BITMAP_OR_AGG": exp.BitmapOrAgg.from_arg_list,
"BOOLXOR": _build_bitwise(exp.Xor, "BOOLXOR"),
"CORR": lambda args: exp.Corr(
this=seq_get(args, 0),
expression=seq_get(args, 1),
null_on_zero_variance=True,
),
"DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
"DATE_TRUNC": _date_trunc_to_time,
"DATEADD": _build_date_time_add(exp.DateAdd),
Expand Down
5 changes: 4 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8258,7 +8258,10 @@ class Upper(Func):


class Corr(Binary, AggFunc):
pass
# Correlation divides by variance(column). If a column has 0 variance, the denominator
# is 0 - some dialects return NaN (DuckDB) while others return NULL (Snowflake).
# `null_on_zero_variance` is set to True at parse time for dialects that return NULL.
arg_types = {"this": True, "expression": True, "null_on_zero_variance": False}


# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/CUME_DIST.html
Expand Down
30 changes: 30 additions & 0 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,6 +1398,36 @@ def test_variance(self):
},
)

def test_corr(self):
self.validate_all(
"SELECT CORR(a, b)",
write={
"duckdb": "SELECT CORR(a, b)",
"postgres": "SELECT CORR(a, b)",
},
)
self.validate_all(
"SELECT CORR(a, b) OVER (PARTITION BY c)",
write={
"duckdb": "SELECT CORR(a, b) OVER (PARTITION BY c)",
"postgres": "SELECT CORR(a, b) OVER (PARTITION BY c)",
},
)
self.validate_all(
"SELECT CORR(a, b) FILTER(WHERE c > 0)",
write={
"duckdb": "SELECT CORR(a, b) FILTER(WHERE c > 0)",
"postgres": "SELECT CORR(a, b) FILTER(WHERE c > 0)",
},
)
self.validate_all(
"SELECT CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d)",
write={
"duckdb": "SELECT CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d)",
"postgres": "SELECT CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d)",
},
)

def test_regexp_binary(self):
"""See https://github.com/tobymao/sqlglot/pull/2404 for details."""
self.assertIsInstance(self.parse_one("'thomas' ~ '.*thomas.*'"), exp.Binary)
Expand Down
40 changes: 39 additions & 1 deletion tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,7 +1159,6 @@ def test_snowflake(self):
},
)
for func in (
"CORR",
"COVAR_POP",
"COVAR_SAMP",
):
Expand Down Expand Up @@ -4437,6 +4436,45 @@ def test_ceil(self):
},
)

def test_corr(self):
self.validate_all(
"SELECT CORR(a, b)",
read={
"snowflake": "SELECT CORR(a, b)",
"postgres": "SELECT CORR(a, b)",
},
write={
"snowflake": "SELECT CORR(a, b)",
"postgres": "SELECT CORR(a, b)",
"duckdb": "SELECT CASE WHEN ISNAN(CORR(a, b)) THEN NULL ELSE CORR(a, b) END",
},
)
self.validate_all(
"SELECT CORR(a, b) OVER (PARTITION BY c)",
read={
"snowflake": "SELECT CORR(a, b) OVER (PARTITION BY c)",
"postgres": "SELECT CORR(a, b) OVER (PARTITION BY c)",
},
write={
"snowflake": "SELECT CORR(a, b) OVER (PARTITION BY c)",
"postgres": "SELECT CORR(a, b) OVER (PARTITION BY c)",
"duckdb": "SELECT CASE WHEN ISNAN(CORR(a, b) OVER (PARTITION BY c)) THEN NULL ELSE CORR(a, b) OVER (PARTITION BY c) END",
},
)

self.validate_all(
"SELECT CORR(a, b) FILTER(WHERE c > 0)",
write={
"duckdb": "SELECT CASE WHEN ISNAN(CORR(a, b) FILTER(WHERE c > 0)) THEN NULL ELSE CORR(a, b) FILTER(WHERE c > 0) END",
},
)
self.validate_all(
"SELECT CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d)",
write={
"duckdb": "SELECT CASE WHEN ISNAN(CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d)) THEN NULL ELSE CORR(a, b) FILTER(WHERE c > 0) OVER (PARTITION BY d) END",
},
)

def test_update_statement(self):
self.validate_identity("UPDATE test SET t = 1 FROM t1")
self.validate_identity("UPDATE test SET t = 1 FROM t2 JOIN t3 ON t2.id = t3.id")
Expand Down