Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions datafusion/physical-plan/src/recursive_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl RecursiveQueryExec {
is_distinct: bool,
) -> Result<Self> {
// Each recursive query needs its own work table
let work_table = Arc::new(WorkTable::new());
let work_table = Arc::new(WorkTable::new(name.clone()));
// Use the same work table for both the WorkTableExec and the recursive term
let recursive_term = assign_work_table(recursive_term, &work_table)?;
let cache = Self::compute_properties(static_term.schema());
Expand Down Expand Up @@ -380,8 +380,6 @@ fn assign_work_table(
work_table_refs += 1;
Ok(Transformed::yes(new_plan))
}
} else if plan.as_any().is::<RecursiveQueryExec>() {
not_impl_err!("Recursive queries cannot be nested")
} else {
Ok(Transformed::no(plan))
}
Expand Down
14 changes: 10 additions & 4 deletions datafusion/physical-plan/src/work_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@ impl ReservedBatches {
#[derive(Debug)]
pub struct WorkTable {
batches: Mutex<Option<ReservedBatches>>,
name: String,
}

impl WorkTable {
/// Create a new work table.
pub(super) fn new() -> Self {
pub(super) fn new(name: String) -> Self {
Self {
batches: Mutex::new(None),
name,
}
}

Expand Down Expand Up @@ -116,10 +118,10 @@ impl WorkTableExec {
pub fn new(name: String, schema: SchemaRef) -> Self {
let cache = Self::compute_properties(Arc::clone(&schema));
Self {
name,
name: name.clone(),
schema,
metrics: ExecutionPlanMetricsSet::new(),
work_table: Arc::new(WorkTable::new()),
work_table: Arc::new(WorkTable::new(name)),
cache,
}
}
Expand Down Expand Up @@ -233,6 +235,10 @@ impl ExecutionPlan for WorkTableExec {
// Down-cast to the expected state type; propagate `None` on failure
let work_table = state.downcast::<WorkTable>().ok()?;

if work_table.name != self.name {
return None; // Different table
}

Some(Arc::new(Self {
name: self.name.clone(),
schema: Arc::clone(&self.schema),
Expand All @@ -251,7 +257,7 @@ mod tests {

#[test]
fn test_work_table() {
let work_table = WorkTable::new();
let work_table = WorkTable::new("test".into());
// Can't take from empty work_table
assert!(work_table.take().is_err());

Expand Down
69 changes: 59 additions & 10 deletions datafusion/sqllogictest/test_files/cte.slt
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,25 @@ SELECT * FROM nodes
3
4

# deduplicating recursive CTE with two variables works
query II
WITH RECURSIVE ranges AS (
SELECT min, max from (VALUES (1, 1), (2, 2)) ranges(min, max)
UNION
SELECT min, max + 1 as max
FROM ranges
WHERE max < 4
)
SELECT * FROM ranges
----
1 1
2 2
1 2
2 3
1 3
2 4
1 4

# setup
statement ok
CREATE EXTERNAL TABLE balance STORED as CSV LOCATION '../core/tests/data/recursive_cte/balance.csv' OPTIONS ('format.has_header' 'true');
Expand Down Expand Up @@ -647,21 +666,51 @@ ORDER BY
3 1400 1
1 2700 2

#expect error from recursive CTE with nested recursive terms
query error DataFusion error: This feature is not implemented: Recursive queries cannot be nested
#nested recursive ctes
query I
WITH RECURSIVE outer_cte AS (
SELECT 1 as a
UNION ALL (
WITH RECURSIVE nested_cte AS (
WITH RECURSIVE nested_cte AS (
SELECT 1 as a
UNION ALL
SELECT a+2 as a
FROM nested_cte where a < 3
)
SELECT outer_cte.a +2
FROM outer_cte JOIN nested_cte USING(a)
WHERE nested_cte.a < 4
)
SELECT a + 2 as a
FROM nested_cte where a < 3
)
SELECT outer_cte.a + 2 as a
FROM outer_cte JOIN nested_cte USING(a)
WHERE nested_cte.a < 4
)
)
SELECT a FROM outer_cte;
----
1
3
5

# Check that CTE name shadowing is returning an error
query error DataFusion error: Error during planning: WITH query name "outer_cte" specified more than once
WITH RECURSIVE outer_cte AS (
SELECT 1 as a
UNION ALL (
WITH RECURSIVE nested_cte AS (
SELECT 1 as a
UNION ALL (
WITH RECURSIVE outer_cte AS (
SELECT 1 as a
UNION ALL
SELECT a + 2 as a
FROM outer_cte where a < 3
)
SELECT nested_cte.a + outer_cte.a as a
FROM nested_cte JOIN outer_cte USING(a)
WHERE outer_cte_cte.a < 8
)
)
SELECT outer_cte.a + nested_cte.a as a
FROM outer_cte JOIN nested_cte USING(a)
WHERE nested_cte.a < 8
)
)
SELECT a FROM outer_cte;

Expand Down