diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index e2df8f9578f97..6aac7f4d5726b 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -86,7 +86,7 @@ impl RecursiveQueryExec { is_distinct: bool, ) -> Result { // Each recursive query needs its own work table - let work_table = Arc::new(WorkTable::new()); + let work_table = Arc::new(WorkTable::new(name.clone())); // Use the same work table for both the WorkTableExec and the recursive term let recursive_term = assign_work_table(recursive_term, &work_table)?; let cache = Self::compute_properties(static_term.schema()); @@ -380,8 +380,6 @@ fn assign_work_table( work_table_refs += 1; Ok(Transformed::yes(new_plan)) } - } else if plan.as_any().is::() { - not_impl_err!("Recursive queries cannot be nested") } else { Ok(Transformed::no(plan)) } diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index a77e7b2cf10fc..e8b1949493a95 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -61,13 +61,15 @@ impl ReservedBatches { #[derive(Debug)] pub struct WorkTable { batches: Mutex>, + name: String, } impl WorkTable { /// Create a new work table. - pub(super) fn new() -> Self { + pub(super) fn new(name: String) -> Self { Self { batches: Mutex::new(None), + name, } } @@ -116,10 +118,10 @@ impl WorkTableExec { pub fn new(name: String, schema: SchemaRef) -> Self { let cache = Self::compute_properties(Arc::clone(&schema)); Self { - name, + name: name.clone(), schema, metrics: ExecutionPlanMetricsSet::new(), - work_table: Arc::new(WorkTable::new()), + work_table: Arc::new(WorkTable::new(name)), cache, } } @@ -233,6 +235,10 @@ impl ExecutionPlan for WorkTableExec { // Down-cast to the expected state type; propagate `None` on failure let work_table = state.downcast::().ok()?; + if work_table.name != self.name { + return None; // Different table + } + Some(Arc::new(Self { name: self.name.clone(), schema: Arc::clone(&self.schema), @@ -251,7 +257,7 @@ mod tests { #[test] fn test_work_table() { - let work_table = WorkTable::new(); + let work_table = WorkTable::new("test".into()); // Can't take from empty work_table assert!(work_table.take().is_err()); diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index fe9077b7f8dc9..fc32c6656fee7 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -125,6 +125,25 @@ SELECT * FROM nodes 3 4 +# deduplicating recursive CTE with two variables works +query II +WITH RECURSIVE ranges AS ( + SELECT min, max from (VALUES (1, 1), (2, 2)) ranges(min, max) + UNION + SELECT min, max + 1 as max + FROM ranges + WHERE max < 4 +) +SELECT * FROM ranges +---- +1 1 +2 2 +1 2 +2 3 +1 3 +2 4 +1 4 + # setup statement ok CREATE EXTERNAL TABLE balance STORED as CSV LOCATION '../core/tests/data/recursive_cte/balance.csv' OPTIONS ('format.has_header' 'true'); @@ -647,21 +666,51 @@ ORDER BY 3 1400 1 1 2700 2 -#expect error from recursive CTE with nested recursive terms -query error DataFusion error: This feature is not implemented: Recursive queries cannot be nested +#nested recursive ctes +query I WITH RECURSIVE outer_cte AS ( SELECT 1 as a UNION ALL ( - WITH RECURSIVE nested_cte AS ( + WITH RECURSIVE nested_cte AS ( SELECT 1 as a UNION ALL - SELECT a+2 as a - FROM nested_cte where a < 3 - ) - SELECT outer_cte.a +2 - FROM outer_cte JOIN nested_cte USING(a) - WHERE nested_cte.a < 4 - ) + SELECT a + 2 as a + FROM nested_cte where a < 3 + ) + SELECT outer_cte.a + 2 as a + FROM outer_cte JOIN nested_cte USING(a) + WHERE nested_cte.a < 4 + ) +) +SELECT a FROM outer_cte; +---- +1 +3 +5 + +# Check that CTE name shadowing is returning an error +query error DataFusion error: Error during planning: WITH query name "outer_cte" specified more than once +WITH RECURSIVE outer_cte AS ( + SELECT 1 as a + UNION ALL ( + WITH RECURSIVE nested_cte AS ( + SELECT 1 as a + UNION ALL ( + WITH RECURSIVE outer_cte AS ( + SELECT 1 as a + UNION ALL + SELECT a + 2 as a + FROM outer_cte where a < 3 + ) + SELECT nested_cte.a + outer_cte.a as a + FROM nested_cte JOIN outer_cte USING(a) + WHERE outer_cte_cte.a < 8 + ) + ) + SELECT outer_cte.a + nested_cte.a as a + FROM outer_cte JOIN nested_cte USING(a) + WHERE nested_cte.a < 8 + ) ) SELECT a FROM outer_cte;