From 40d67e5726c94374d878429a636b9f6b9ef2e93c Mon Sep 17 00:00:00 2001
From: Aaron Hill
Date: Thu, 22 Jan 2026 14:00:44 -0500
Subject: [PATCH 1/6] Refactor TaskContext to hold and use a Durable client

We now delegate to the existing 'spawn_by_name' (wrapping it in a
checkpoint in TaskContext). This lets us re-use all of the existing
logic, including the OpenTelemetry context propagation logic.

This will give us a tree structure in OpenTelemetry - subtasks will use
their parent task as the parent trace.
---
 src/client.rs  |  40 ++++++++++---
 src/context.rs | 153 +++++++++++--------------------------------------
 src/error.rs   |  11 ++++
 src/worker.rs  | 105 +++++++++------------------------
 4 files changed, 104 insertions(+), 205 deletions(-)

diff --git a/src/client.rs b/src/client.rs
index 28c2383..8fae450 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -115,6 +115,27 @@ where
     state: State,
 }
 
+impl Durable {
+    /// TODO: Decide if we want to implement `Clone`,
+    /// which will allow consumers to clone `Durable`
+    /// Currently, we only allow cloning within the crate
+    /// via this method
+    pub(crate) fn clone_inner(&self) -> Durable {
+        Durable {
+            pool: self.pool.clone(),
+            owns_pool: self.owns_pool,
+            queue_name: self.queue_name.clone(),
+            spawn_defaults: self.spawn_defaults.clone(),
+            registry: self.registry.clone(),
+            state: self.state.clone(),
+        }
+    }
+
+    pub(crate) fn registry(&self) -> &Arc>> {
+        &self.registry
+    }
+}
+
 /// Builder for configuring a [`Durable`] client.
 ///
 /// # Example
@@ -692,14 +713,17 @@ where
         });
     }
 
-        Ok(Worker::start(
-            self.pool.clone(),
-            self.queue_name.clone(),
-            self.registry.clone(),
-            options,
-            self.state.clone(),
-            self.spawn_defaults.clone(),
-        )
+        // For now, we just manually construct a `Durable` with clones
+        // of our fields. In the future, we may want to make `Durable`
+        // implement `Clone`, or make an inner struct and wrap it in an `Arc`
+        Ok(Worker::start(Durable {
+            pool: self.pool.clone(),
+            owns_pool: self.owns_pool,
+            queue_name: self.queue_name.clone(),
+            spawn_defaults: self.spawn_defaults.clone(),
+            registry: self.registry.clone(),
+            state: self.state.clone(),
+        }, options)
         .await)
     }
 
diff --git a/src/context.rs b/src/context.rs
index b3128c4..a1a5710 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -1,18 +1,16 @@
 use chrono::{DateTime, Utc};
 use serde::{Serialize, de::DeserializeOwned};
 use serde_json::Value as JsonValue;
-use sqlx::PgPool;
 use std::collections::HashMap;
-use std::sync::Arc;
 use std::time::Duration;
-use tokio::sync::RwLock;
 use uuid::Uuid;
 
+use crate::Durable;
 use crate::error::{ControlFlow, TaskError, TaskResult};
-use crate::task::{Task, TaskRegistry};
+use crate::task::Task;
 use crate::types::{
-    AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnDefaults,
-    SpawnOptions, SpawnResultRow, TaskHandle,
+    AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnOptions,
+    TaskHandle,
 };
 use crate::worker::LeaseExtender;
 
@@ -52,14 +50,9 @@ where
     pub attempt: i32,
 
     // Internal state
-    pool: PgPool,
-    queue_name: String,
-    #[allow(dead_code)]
+    durable: Durable,
     task: ClaimedTask,
     claim_timeout: Duration,
-
-    state: State,
-
     /// Checkpoint cache: loaded on creation, updated on writes.
     checkpoint_cache: HashMap,
 
@@ -69,12 +62,6 @@
 
     /// Notifies the worker when the lease is extended via step() or heartbeat().
     lease_extender: LeaseExtender,
-
-    /// Task registry for validating spawn_by_name calls.
- registry: Arc>>, - - /// Default settings for subtasks spawned via spawn/spawn_by_name. - spawn_defaults: SpawnDefaults, } /// Validate that a user-provided step name doesn't use reserved prefix. @@ -95,23 +82,19 @@ where /// Loads all existing checkpoints into the cache. #[allow(clippy::too_many_arguments)] pub(crate) async fn create( - pool: PgPool, - queue_name: String, + durable: Durable, task: ClaimedTask, claim_timeout: Duration, lease_extender: LeaseExtender, - registry: Arc>>, - state: State, - spawn_defaults: SpawnDefaults, ) -> Result { // Load all checkpoints for this task into cache let checkpoints: Vec = sqlx::query_as( "SELECT checkpoint_name, state, owner_run_id, updated_at FROM durable.get_task_checkpoint_states($1, $2)", ) - .bind(&queue_name) + .bind(durable.queue_name()) .bind(task.task_id) - .fetch_all(&pool) + .fetch_all(durable.pool()) .await?; let mut cache = HashMap::new(); @@ -123,16 +106,12 @@ where task_id: task.task_id, run_id: task.run_id, attempt: task.attempt, - pool, - queue_name, + durable, task, claim_timeout, checkpoint_cache: cache, step_counters: HashMap::new(), lease_extender, - registry, - state, - spawn_defaults, }) } @@ -206,7 +185,7 @@ where } // Execute the step - let result = f(params, self.state.clone()).await?; + let result = f(params, self.durable.state().clone()).await?; // Persist checkpoint (also extends claim lease) #[cfg(feature = "telemetry")] @@ -275,13 +254,13 @@ where // set_task_checkpoint_state also extends the claim let query = "SELECT durable.set_task_checkpoint_state($1, $2, $3, $4, $5, $6)"; sqlx::query(query) - .bind(&self.queue_name) + .bind(self.durable.queue_name()) .bind(self.task_id) .bind(name) .bind(&state_json) .bind(self.run_id) .bind(self.claim_timeout.as_secs() as i32) - .execute(&self.pool) + .execute(self.durable.pool()) .await?; self.checkpoint_cache.insert(name.to_string(), state_json); @@ -315,12 +294,12 @@ where let (needs_suspend,): (bool,) = sqlx::query_as("SELECT durable.sleep_for($1, $2, $3, $4, $5)") - .bind(&self.queue_name) + .bind(self.durable.queue_name()) .bind(self.task_id) .bind(self.run_id) .bind(&checkpoint_name) .bind(duration_ms) - .fetch_one(&self.pool) + .fetch_one(self.durable.pool()) .await?; if needs_suspend { @@ -390,13 +369,13 @@ where FROM durable.await_event($1, $2, $3, $4, $5, $6)"; let result: AwaitEventResult = sqlx::query_as(query) - .bind(&self.queue_name) + .bind(self.durable.queue_name()) .bind(self.task_id) .bind(self.run_id) .bind(&checkpoint_name) .bind(event_name) .bind(timeout_secs) - .fetch_one(&self.pool) + .fetch_one(self.durable.pool()) .await?; if result.should_suspend { @@ -434,10 +413,10 @@ where let payload_json = serde_json::to_value(payload)?; let query = "SELECT durable.emit_event($1, $2, $3)"; sqlx::query(query) - .bind(&self.queue_name) + .bind(self.durable.queue_name()) .bind(event_name) .bind(&payload_json) - .execute(&self.pool) + .execute(self.durable.pool()) .await?; Ok(()) @@ -474,10 +453,10 @@ where let query = "SELECT durable.extend_claim($1, $2, $3)"; sqlx::query(query) - .bind(&self.queue_name) + .bind(self.durable.queue_name()) .bind(self.run_id) .bind(extend_by.as_secs() as i32) - .execute(&self.pool) + .execute(self.durable.pool()) .await?; // Notify worker that lease was extended so it can reset timers @@ -642,82 +621,25 @@ where validate_user_name(name)?; let checkpoint_name = self.get_checkpoint_name(&format!("$spawn:{name}"), ¶ms)?; - // Validate headers don't use reserved prefix - if let Some(ref headers) = options.headers { - for key in 
headers.keys() { - if key.starts_with("durable::") { - return Err(TaskError::Validation { - message: format!( - "Header key '{}' uses reserved prefix 'durable::'. User headers cannot start with 'durable::'.", - key - ), - }); - } - } - } - // Return cached task_id if already spawned if let Some(cached) = self.checkpoint_cache.get(&checkpoint_name) { let task_id: Uuid = serde_json::from_value(cached.clone())?; return Ok(TaskHandle::new(task_id)); } - // Validate that the task is registered - { - let registry = self.registry.read().await; - if !registry.contains_key(task_name) { - return Err(TaskError::Validation { - message: format!( - "Unknown task: {}. Task must be registered before spawning.", - task_name - ), - }); - } - } - - // Apply defaults if not set - let options = SpawnOptions { - max_attempts: Some( - options - .max_attempts - .unwrap_or(self.spawn_defaults.max_attempts), - ), - retry_strategy: options - .retry_strategy - .or_else(|| self.spawn_defaults.retry_strategy.clone()), - cancellation: options - .cancellation - .or_else(|| self.spawn_defaults.cancellation.clone()), - ..options - }; - - // Build options JSON, merging user options with parent_task_id - #[derive(Serialize)] - struct SubtaskOptions<'a> { - parent_task_id: Uuid, - #[serde(flatten)] - options: &'a SpawnOptions, - } - let options_json = serde_json::to_value(SubtaskOptions { - parent_task_id: self.task_id, - options: &options, - })?; - - let row: SpawnResultRow = sqlx::query_as( - "SELECT task_id, run_id, attempt FROM durable.spawn_task($1, $2, $3, $4)", - ) - .bind(&self.queue_name) - .bind(task_name) - .bind(¶ms) - .bind(&options_json) - .fetch_one(&self.pool) - .await?; - + let spawned_task = self + .durable + .spawn_by_name(task_name, params, options) + .await + .map_err(|e| TaskError::SubtaskSpawnFailed { + name: task_name.to_string(), + error: e, + })?; // Checkpoint the spawn - self.persist_checkpoint(&checkpoint_name, &row.task_id) + self.persist_checkpoint(&checkpoint_name, &spawned_task.task_id) .await?; - Ok(TaskHandle::new(row.task_id)) + Ok(TaskHandle::new(spawned_task.task_id)) } /// Wait for a subtask to complete and return its result. 
@@ -781,13 +703,13 @@ where FROM durable.await_event($1, $2, $3, $4, $5, $6)"; let result: AwaitEventResult = sqlx::query_as(query) - .bind(&self.queue_name) + .bind(self.durable.queue_name()) .bind(self.task_id) .bind(self.run_id) .bind(&checkpoint_name) .bind(&event_name) .bind(None::) // No timeout - .fetch_one(&self.pool) + .fetch_one(self.durable.pool()) .await?; if result.should_suspend { @@ -837,6 +759,7 @@ mod tests { #![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)] use super::*; use crate::{Durable, MIGRATOR}; + use sqlx::PgPool; // Note that this is a 'unit' test in order to call private methods, but it still needs Postgres to be running #[sqlx::test(migrator = "MIGRATOR")] @@ -849,8 +772,7 @@ mod tests { .expect("Failed to create Durable client"); client.create_queue(None).await.unwrap(); let mut ctx = TaskContext::create( - pool, - "my_test_queue".to_string(), + client, ClaimedTask { task_id: Uuid::now_v7(), run_id: Uuid::now_v7(), @@ -865,13 +787,6 @@ mod tests { }, Duration::from_secs(10), LeaseExtender::dummy_for_tests(), - Arc::new(RwLock::new(TaskRegistry::new())), - (), - SpawnDefaults { - max_attempts: 5, - retry_strategy: None, - cancellation: None, - }, ) .await .unwrap(); diff --git a/src/error.rs b/src/error.rs index 8972ad5..c4d6cb2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -73,6 +73,10 @@ pub enum TaskError { #[error("serialization error: {0}")] Serialization(serde_json::Error), + //// Error occurred while trying to spawn a subtask + #[error("failed to spawn subtask `{name}`: {error}")] + SubtaskSpawnFailed { name: String, error: DurableError }, + /// A child task failed. /// /// Returned by [`TaskContext::join`](crate::TaskContext::join) when the child @@ -220,6 +224,13 @@ pub fn serialize_task_error(err: &TaskError) -> JsonValue { "message": e.to_string(), }) } + TaskError::SubtaskSpawnFailed { name, error } => { + serde_json::json!({ + "name": "SubtaskSpawnFailed", + "message": error.to_string(), + "name": name, + }) + } TaskError::ChildFailed { step_name, message } => { serde_json::json!({ "name": "ChildFailed", diff --git a/src/worker.rs b/src/worker.rs index f209b7b..00046f9 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -3,15 +3,15 @@ use serde_json::Value as JsonValue; use sqlx::PgPool; use std::sync::Arc; use std::time::Duration; -use tokio::sync::{RwLock, Semaphore, broadcast, mpsc}; +use tokio::sync::{Semaphore, broadcast, mpsc}; use tokio::time::{Instant, sleep, sleep_until}; use tracing::{Instrument, Span}; use uuid::Uuid; +use crate::Durable; use crate::context::TaskContext; use crate::error::{ControlFlow, TaskError, serialize_task_error}; -use crate::task::TaskRegistry; -use crate::types::{ClaimedTask, ClaimedTaskRow, SpawnDefaults, WorkerOptions}; +use crate::types::{ClaimedTask, ClaimedTaskRow, WorkerOptions}; /// Notifies the worker that the lease has been extended. /// Used by TaskContext to reset warning/fatal timers. 
@@ -61,14 +61,7 @@ pub struct Worker { } impl Worker { - pub(crate) async fn start( - pool: PgPool, - queue_name: String, - registry: Arc>>, - options: WorkerOptions, - state: State, - spawn_defaults: SpawnDefaults, - ) -> Self + pub(crate) async fn start(durable: Durable, options: WorkerOptions) -> Self where State: Clone + Send + Sync + 'static, { @@ -85,16 +78,7 @@ impl Worker { ) }); - let handle = tokio::spawn(Self::run_loop( - pool, - queue_name, - registry, - options, - worker_id, - shutdown_rx, - state, - spawn_defaults, - )); + let handle = tokio::spawn(Self::run_loop(durable, options, worker_id, shutdown_rx)); Self { shutdown_tx, @@ -113,14 +97,10 @@ impl Worker { #[allow(clippy::too_many_arguments)] async fn run_loop( - pool: PgPool, - queue_name: String, - registry: Arc>>, + durable: Durable, options: WorkerOptions, worker_id: String, mut shutdown_rx: broadcast::Receiver<()>, - state: State, - spawn_defaults: SpawnDefaults, ) where State: Clone + Send + Sync + 'static, { @@ -172,8 +152,7 @@ impl Worker { } let tasks = match Self::claim_tasks( - &pool, - &queue_name, + &durable, &worker_id, claim_timeout, permits.len(), @@ -189,23 +168,14 @@ impl Worker { let permits = permits.into_iter().take(tasks.len()); for (task, permit) in tasks.into_iter().zip(permits) { - let pool = pool.clone(); - let queue_name = queue_name.clone(); - let registry = registry.clone(); let done_tx = done_tx.clone(); - let state = state.clone(); - let spawn_defaults = spawn_defaults.clone(); - + let durable = durable.clone_inner(); tokio::spawn(async move { Self::execute_task( - pool, - queue_name, - registry, + durable, task, claim_timeout, fatal_on_lease_timeout, - state, - spawn_defaults, ).await; drop(permit); @@ -226,9 +196,8 @@ impl Worker { fields(queue = %queue_name, worker_id = %worker_id, count = count) ) )] - async fn claim_tasks( - pool: &PgPool, - queue_name: &str, + async fn claim_tasks( + durable: &Durable, worker_id: &str, claim_timeout: Duration, count: usize, @@ -241,11 +210,11 @@ impl Worker { FROM durable.claim_task($1, $2, $3, $4)"; let rows: Vec = sqlx::query_as(query) - .bind(queue_name) + .bind(durable.queue_name()) .bind(worker_id) .bind(claim_timeout.as_secs() as i32) .bind(count as i32) - .fetch_all(pool) + .fetch_all(durable.pool()) .await?; let tasks: Vec = rows @@ -267,21 +236,17 @@ impl Worker { #[allow(clippy::too_many_arguments)] async fn execute_task( - pool: PgPool, - queue_name: String, - registry: Arc>>, + durable: Durable, task: ClaimedTask, claim_timeout: Duration, fatal_on_lease_timeout: bool, - state: State, - spawn_defaults: SpawnDefaults, ) where State: Clone + Send + Sync + 'static, { // Create span for task execution, linked to parent trace context if available let span = tracing::info_span!( "durable.worker.execute_task", - queue = %queue_name, + queue = %durable.queue_name(), task_id = %task.task_id, run_id = %task.run_id, task_name = %task.task_name, @@ -300,30 +265,17 @@ impl Worker { } } - Self::execute_task_inner( - pool, - queue_name, - registry, - task, - claim_timeout, - fatal_on_lease_timeout, - state, - spawn_defaults, - ) - .instrument(span) - .await + Self::execute_task_inner(durable, task, claim_timeout, fatal_on_lease_timeout) + .instrument(span) + .await } #[allow(clippy::too_many_arguments)] async fn execute_task_inner( - pool: PgPool, - queue_name: String, - registry: Arc>>, + durable: Durable, task: ClaimedTask, claim_timeout: Duration, fatal_on_lease_timeout: bool, - state: State, - spawn_defaults: SpawnDefaults, ) where State: Clone + Send 
+ Sync + 'static, { @@ -342,34 +294,30 @@ impl Worker { // Create task context let ctx = match TaskContext::create( - pool.clone(), - queue_name.clone(), + durable.clone_inner(), task.clone(), claim_timeout, lease_extender, - registry.clone(), - state.clone(), - spawn_defaults, ) .await { Ok(ctx) => ctx, Err(e) => { tracing::error!("Failed to create task context: {}", e); - Self::fail_run(&pool, &queue_name, task.task_id, task.run_id, &e.into()).await; + Self::fail_run(durable.pool(), durable.queue_name(), task.task_id, task.run_id, &e.into()).await; return; } }; // Look up handler - let registry = registry.read().await; + let registry = durable.registry().read().await; let handler = match registry.get(task.task_name.as_str()) { Some(h) => h.clone(), None => { tracing::error!("Unknown task: {}", task.task_name); Self::fail_run( - &pool, - &queue_name, + durable.pool(), + durable.queue_name(), task.task_id, task.run_id, &TaskError::Validation { @@ -387,6 +335,7 @@ impl Worker { // all the way through to the individual task steps let task_handle = tokio::spawn({ let params = task.params.clone(); + let state = durable.state().clone(); (async move { handler.execute(params, ctx, state).await }).instrument(Span::current()) }); let abort_handle = task_handle.abort_handle(); @@ -513,7 +462,7 @@ impl Worker { { outcome = "completed"; } - Self::complete_run(&pool, &queue_name, task.task_id, task.run_id, output).await; + Self::complete_run(durable.pool(), durable.queue_name(), task.task_id, task.run_id, output).await; #[cfg(feature = "telemetry")] crate::telemetry::record_task_completed(&queue_name_for_metrics, &task_name); @@ -541,7 +490,7 @@ impl Worker { outcome = "failed"; } tracing::error!("Task {} failed: {}", task_label, e); - Self::fail_run(&pool, &queue_name, task.task_id, task.run_id, e).await; + Self::fail_run(durable.pool(), durable.queue_name(), task.task_id, task.run_id, e).await; #[cfg(feature = "telemetry")] crate::telemetry::record_task_failed( From 82794e58826afe6347ea38fe6f82cd12e6361332 Mon Sep 17 00:00:00 2001 From: Aaron Hill Date: Thu, 22 Jan 2026 14:14:46 -0500 Subject: [PATCH 2/6] Run clippy --- src/client.rs | 19 +++++++++++-------- src/worker.rs | 27 ++++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/client.rs b/src/client.rs index 8fae450..f032112 100644 --- a/src/client.rs +++ b/src/client.rs @@ -716,14 +716,17 @@ where // For now, we just manually construct a `Durable` with clones // of our fields. 
In the future, we may want to make `Durable` // implement `Clone`, or make an inner struct and wrap it in an `Arc` - Ok(Worker::start(Durable { - pool: self.pool.clone(), - owns_pool: self.owns_pool, - queue_name: self.queue_name.clone(), - spawn_defaults: self.spawn_defaults.clone(), - registry: self.registry.clone(), - state: self.state.clone(), - }, options) + Ok(Worker::start( + Durable { + pool: self.pool.clone(), + owns_pool: self.owns_pool, + queue_name: self.queue_name.clone(), + spawn_defaults: self.spawn_defaults.clone(), + registry: self.registry.clone(), + state: self.state.clone(), + }, + options, + ) .await) } diff --git a/src/worker.rs b/src/worker.rs index 00046f9..a4eca78 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -304,7 +304,14 @@ impl Worker { Ok(ctx) => ctx, Err(e) => { tracing::error!("Failed to create task context: {}", e); - Self::fail_run(durable.pool(), durable.queue_name(), task.task_id, task.run_id, &e.into()).await; + Self::fail_run( + durable.pool(), + durable.queue_name(), + task.task_id, + task.run_id, + &e.into(), + ) + .await; return; } }; @@ -462,7 +469,14 @@ impl Worker { { outcome = "completed"; } - Self::complete_run(durable.pool(), durable.queue_name(), task.task_id, task.run_id, output).await; + Self::complete_run( + durable.pool(), + durable.queue_name(), + task.task_id, + task.run_id, + output, + ) + .await; #[cfg(feature = "telemetry")] crate::telemetry::record_task_completed(&queue_name_for_metrics, &task_name); @@ -490,7 +504,14 @@ impl Worker { outcome = "failed"; } tracing::error!("Task {} failed: {}", task_label, e); - Self::fail_run(durable.pool(), durable.queue_name(), task.task_id, task.run_id, e).await; + Self::fail_run( + durable.pool(), + durable.queue_name(), + task.task_id, + task.run_id, + e, + ) + .await; #[cfg(feature = "telemetry")] crate::telemetry::record_task_failed( From e09641b5c95b99031dbb6d7b07d5b9c635fd1e40 Mon Sep 17 00:00:00 2001 From: Aaron Hill Date: Thu, 22 Jan 2026 14:30:42 -0500 Subject: [PATCH 3/6] Fix telemetry --- src/context.rs | 2 +- src/worker.rs | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/context.rs b/src/context.rs index a1a5710..b85f686 100644 --- a/src/context.rs +++ b/src/context.rs @@ -197,7 +197,7 @@ where { let duration = checkpoint_start.elapsed().as_secs_f64(); crate::telemetry::record_checkpoint_duration( - &self.queue_name, + self.durable.queue_name(), &self.task.task_name, duration, ); diff --git a/src/worker.rs b/src/worker.rs index a4eca78..f7742af 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -112,7 +112,7 @@ impl Worker { // Mark worker as active #[cfg(feature = "telemetry")] - crate::telemetry::set_worker_active(&queue_name, &worker_id, true); + crate::telemetry::set_worker_active(durable.queue_name(), &worker_id, true); // Semaphore limits concurrent task execution let semaphore = Arc::new(Semaphore::new(concurrency)); @@ -127,7 +127,7 @@ impl Worker { tracing::info!("Worker shutting down, waiting for in-flight tasks..."); #[cfg(feature = "telemetry")] - crate::telemetry::set_worker_active(&queue_name, &worker_id, false); + crate::telemetry::set_worker_active(durable.queue_name(), &worker_id, false); drop(done_tx); while done_rx.recv().await.is_some() {} @@ -192,8 +192,8 @@ impl Worker { tracing::instrument( level = "debug", name = "durable.worker.claim_tasks", - skip(pool), - fields(queue = %queue_name, worker_id = %worker_id, count = count) + skip(durable), + fields(queue = %durable.queue_name(), worker_id = %worker_id, count = count) ) 
)] async fn claim_tasks( @@ -225,9 +225,9 @@ impl Worker { #[cfg(feature = "telemetry")] { let duration = start.elapsed().as_secs_f64(); - crate::telemetry::record_task_claim_duration(queue_name, duration); + crate::telemetry::record_task_claim_duration(durable.queue_name(), duration); for _ in &tasks { - crate::telemetry::record_task_claimed(queue_name); + crate::telemetry::record_task_claimed(durable.queue_name()); } } @@ -285,7 +285,7 @@ impl Worker { #[cfg(feature = "telemetry")] let task_name = task.task_name.clone(); #[cfg(feature = "telemetry")] - let queue_name_for_metrics = queue_name.clone(); + let queue_name_for_metrics = durable.queue_name().to_string(); let start_time = Instant::now(); // Create lease extension channel - TaskContext will notify when lease is extended From 82a54283ae1ca51e245980d90462f55ec060f36a Mon Sep 17 00:00:00 2001 From: Aaron Hill Date: Thu, 22 Jan 2026 16:01:27 -0500 Subject: [PATCH 4/6] Set parent_task_id --- src/client.rs | 3 ++ src/context.rs | 9 ++++- src/types.rs | 3 ++ tests/checkpoint_test.rs | 44 ++++++++++++----------- tests/common/helpers.rs | 1 + tests/common/tasks.rs | 13 +++---- tests/crash_test.rs | 45 ++++++++++++----------- tests/event_test.rs | 29 ++++++++------- tests/execution_test.rs | 14 ++++---- tests/fanout_test.rs | 47 ++++++++++++------------ tests/lock_order_test.rs | 11 +++--- tests/partition_test.rs | 22 ++++++------ tests/retry_test.rs | 42 ++++++++++++---------- tests/spawn_test.rs | 78 ++++++++++++++++------------------------ 14 files changed, 185 insertions(+), 176 deletions(-) diff --git a/src/client.rs b/src/client.rs index f032112..a98f522 100644 --- a/src/client.rs +++ b/src/client.rs @@ -24,6 +24,8 @@ struct SpawnOptionsDb<'a> { retry_strategy: Option<&'a RetryStrategy>, #[serde(skip_serializing_if = "Option::is_none")] cancellation: Option, + #[serde(skip_serializing_if = "Option::is_none")] + parent_task_id: Option<&'a Uuid>, } /// Internal struct for serializing cancellation policy (only non-None fields). @@ -578,6 +580,7 @@ where .cancellation .as_ref() .and_then(CancellationPolicyDb::from_policy), + parent_task_id: options.parent_task_id.as_ref(), }; serde_json::to_value(db_options) } diff --git a/src/context.rs b/src/context.rs index b85f686..10993fe 100644 --- a/src/context.rs +++ b/src/context.rs @@ -629,7 +629,14 @@ where let spawned_task = self .durable - .spawn_by_name(task_name, params, options) + .spawn_by_name( + task_name, + params, + SpawnOptions { + parent_task_id: Some(self.task_id), + ..options + }, + ) .await .map_err(|e| TaskError::SubtaskSpawnFailed { name: task_name.to_string(), diff --git a/src/types.rs b/src/types.rs index 5319dcf..3cfaa07 100644 --- a/src/types.rs +++ b/src/types.rs @@ -134,6 +134,9 @@ pub struct SpawnOptions { /// Cancellation policy #[serde(skip_serializing_if = "Option::is_none")] pub cancellation: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) parent_task_id: Option, } /// Options for configuring a worker. 
diff --git a/tests/checkpoint_test.rs b/tests/checkpoint_test.rs index 661902a..3d43186 100644 --- a/tests/checkpoint_test.rs +++ b/tests/checkpoint_test.rs @@ -41,12 +41,13 @@ async fn test_checkpoint_prevents_step_reexecution(pool: PgPool) -> sqlx::Result StepCountingParams { fail_after_step2: true, }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(2), - ..Default::default() + }); + opts.max_attempts = Some(2); + opts }, ) .await @@ -141,12 +142,13 @@ async fn test_deterministic_rand_preserved_on_retry(pool: PgPool) -> sqlx::Resul DeterministicReplayParams { fail_on_first_attempt: true, }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(2), - ..Default::default() + }); + opts.max_attempts = Some(2); + opts }, ) .await @@ -205,12 +207,13 @@ async fn test_deterministic_now_preserved_on_retry(pool: PgPool) -> sqlx::Result DeterministicReplayParams { fail_on_first_attempt: true, }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(2), - ..Default::default() + }); + opts.max_attempts = Some(2); + opts }, ) .await @@ -266,12 +269,13 @@ async fn test_deterministic_uuid7_preserved_on_retry(pool: PgPool) -> sqlx::Resu DeterministicReplayParams { fail_on_first_attempt: true, }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(2), - ..Default::default() + }); + opts.max_attempts = Some(2); + opts }, ) .await diff --git a/tests/common/helpers.rs b/tests/common/helpers.rs index cfd5085..57aa0b3 100644 --- a/tests/common/helpers.rs +++ b/tests/common/helpers.rs @@ -220,6 +220,7 @@ pub async fn wait_for_task_terminal( let poll_interval = Duration::from_millis(50); while start.elapsed() < timeout { + eprintln!("Task: {:?}", get_task_state(pool, queue, task_id).await); if let Some(state) = get_task_state(pool, queue, task_id).await? 
&& (state == "completed" || state == "failed" || state == "cancelled") { diff --git a/tests/common/tasks.rs b/tests/common/tasks.rs index ea65aa1..c760866 100644 --- a/tests/common/tasks.rs +++ b/tests/common/tasks.rs @@ -696,14 +696,11 @@ impl Task<()> for SpawnFailingChildTask { ) -> TaskResult { // Spawn with max_attempts=1 so child fails immediately without retries let handle: TaskHandle<()> = ctx - .spawn::( - "child", - (), - SpawnOptions { - max_attempts: Some(1), - ..Default::default() - }, - ) + .spawn::("child", (), { + let mut opts = SpawnOptions::default(); + opts.max_attempts = Some(1); + opts + }) .await?; // This should fail because child fails ctx.join(handle).await?; diff --git a/tests/crash_test.rs b/tests/crash_test.rs index 32b4822..bf186b0 100644 --- a/tests/crash_test.rs +++ b/tests/crash_test.rs @@ -41,12 +41,13 @@ async fn test_crash_mid_step_resumes_from_checkpoint(pool: PgPool) -> sqlx::Resu StepCountingParams { fail_after_step2: true, }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(3), - ..Default::default() + }); + opts.max_attempts = Some(3); + opts }, ) .await @@ -303,6 +304,7 @@ async fn test_heartbeat_prevents_lease_expiration(pool: PgPool) -> sqlx::Result< /// Uses SingleSpawnTask which already exists and spawns a child. #[sqlx::test(migrator = "MIGRATOR")] async fn test_spawn_idempotency_after_retry(pool: PgPool) -> sqlx::Result<()> { + tracing_subscriber::fmt::init(); use common::tasks::{DoubleTask, SingleSpawnParams, SingleSpawnTask}; let client = create_client(pool.clone(), "crash_spawn").await; @@ -381,12 +383,13 @@ async fn test_step_idempotency_after_retry(pool: PgPool) -> sqlx::Result<()> { StepCountingParams { fail_after_step2: false, // Don't fail, just complete }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(2), - ..Default::default() + }); + opts.max_attempts = Some(2); + opts }, ) .await @@ -449,12 +452,13 @@ async fn test_cpu_bound_outlives_lease(pool: PgPool) -> sqlx::Result<()> { CpuBoundParams { duration_ms: 10000, // 10 seconds }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(3), - ..Default::default() + }); + opts.max_attempts = Some(3); + opts }, ) .await @@ -509,12 +513,13 @@ async fn test_slow_task_outlives_lease(pool: PgPool) -> sqlx::Result<()> { SlowNoHeartbeatParams { sleep_ms: 30000, // 30 seconds - much longer than lease }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(5), - ..Default::default() + }); + opts.max_attempts = Some(5); + opts }, ) .await diff --git a/tests/event_test.rs b/tests/event_test.rs index 0bc690f..36cd756 100644 --- a/tests/event_test.rs +++ b/tests/event_test.rs @@ -163,10 +163,11 @@ async fn test_event_timeout_triggers(pool: PgPool) -> sqlx::Result<()> { event_name: "never_emitted".to_string(), timeout_seconds: Some(1), // 1 second timeout }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::None), - 
max_attempts: Some(1), - ..Default::default() + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::None); + opts.max_attempts = Some(1); + opts }, ) .await @@ -284,12 +285,13 @@ async fn test_event_payload_preserved_on_retry(pool: PgPool) -> sqlx::Result<()> EventThenFailParams { event_name: "retry_event".to_string(), }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(2), - ..Default::default() + }); + opts.max_attempts = Some(2); + opts }, ) .await @@ -728,10 +730,11 @@ async fn test_event_timeout_error_payload(pool: PgPool) -> sqlx::Result<()> { event_name: "never_arrives".to_string(), timeout_seconds: Some(1), }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::None), - max_attempts: Some(1), - ..Default::default() + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::None); + opts.max_attempts = Some(1); + opts }, ) .await diff --git a/tests/execution_test.rs b/tests/execution_test.rs index f6f86e3..da4beb7 100644 --- a/tests/execution_test.rs +++ b/tests/execution_test.rs @@ -610,14 +610,12 @@ async fn test_reserved_prefix_error_payload(pool: PgPool) -> sqlx::Result<()> { client.register::().await.unwrap(); let spawn_result = client - .spawn_with_options::( - (), - SpawnOptions { - retry_strategy: Some(RetryStrategy::None), - max_attempts: Some(1), - ..Default::default() - }, - ) + .spawn_with_options::((), { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::None); + opts.max_attempts = Some(1); + opts + }) .await .expect("Failed to spawn task"); diff --git a/tests/fanout_test.rs b/tests/fanout_test.rs index 9247516..37b5036 100644 --- a/tests/fanout_test.rs +++ b/tests/fanout_test.rs @@ -257,13 +257,11 @@ async fn test_child_failure_propagates_to_parent(pool: PgPool) -> sqlx::Result<( // Spawn parent task that will spawn a failing child // Use max_attempts=1 for both parent and child to avoid long retry waits let spawn_result = client - .spawn_with_options::( - (), - durable::SpawnOptions { - max_attempts: Some(1), - ..Default::default() - }, - ) + .spawn_with_options::((), { + let mut opts = durable::SpawnOptions::default(); + opts.max_attempts = Some(1); + opts + }) .await .expect("Failed to spawn task"); @@ -383,12 +381,13 @@ async fn test_cascade_cancel_when_parent_auto_cancelled_by_max_duration( SpawnSlowChildParams { child_sleep_ms: 10000, // 10 seconds }, - SpawnOptions { - cancellation: Some(CancellationPolicy { + { + let mut opts = SpawnOptions::default(); + opts.cancellation = Some(CancellationPolicy { max_pending_time: None, max_running_time: Some(Duration::from_secs(2)), // 2 seconds max duration - }), - ..Default::default() + }); + opts }, ) .await @@ -623,9 +622,10 @@ async fn test_join_cancelled_child_returns_child_cancelled_error(pool: PgPool) - JoinCancelledChildParams { child_sleep_ms: 10000, // 10 seconds - plenty of time to cancel }, - durable::SpawnOptions { - max_attempts: Some(1), - ..Default::default() + { + let mut opts = durable::SpawnOptions::default(); + opts.max_attempts = Some(1); + opts }, ) .await @@ -710,13 +710,11 @@ async fn test_child_failed_error_contains_message(pool: PgPool) -> sqlx::Result< // Spawn parent task with max_attempts=1 let spawn_result = client - .spawn_with_options::( - (), - durable::SpawnOptions { - max_attempts: Some(1), - ..Default::default() 
- }, - ) + .spawn_with_options::((), { + let mut opts = durable::SpawnOptions::default(); + opts.max_attempts = Some(1); + opts + }) .await .expect("Failed to spawn task"); @@ -789,9 +787,10 @@ async fn test_join_timeout_when_parent_claim_expires(pool: PgPool) -> sqlx::Resu SpawnSlowChildParams { child_sleep_ms: 30000, // 30 seconds - much longer than claim_timeout }, - durable::SpawnOptions { - max_attempts: Some(1), - ..Default::default() + { + let mut opts = durable::SpawnOptions::default(); + opts.max_attempts = Some(1); + opts }, ) .await diff --git a/tests/lock_order_test.rs b/tests/lock_order_test.rs index 9056154..be8cbb6 100644 --- a/tests/lock_order_test.rs +++ b/tests/lock_order_test.rs @@ -89,12 +89,13 @@ async fn test_fail_run_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { FailingParams { error_message: "intentional failure".to_string(), }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(2), - ..Default::default() + }); + opts.max_attempts = Some(2); + opts }, ) .await diff --git a/tests/partition_test.rs b/tests/partition_test.rs index e02abfc..0a44298 100644 --- a/tests/partition_test.rs +++ b/tests/partition_test.rs @@ -35,12 +35,13 @@ async fn test_db_connection_lost_during_checkpoint(pool: PgPool) -> sqlx::Result StepCountingParams { fail_after_step2: true, }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(3), - ..Default::default() + }); + opts.max_attempts = Some(3); + opts }, ) .await @@ -100,12 +101,13 @@ async fn test_stale_worker_checkpoint_rejected(pool: PgPool) -> sqlx::Result<()> SlowNoHeartbeatParams { sleep_ms: 30000, // 30 seconds }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(5), - ..Default::default() + }); + opts.max_attempts = Some(5); + opts }, ) .await diff --git a/tests/retry_test.rs b/tests/retry_test.rs index fbc2ca8..937e5d6 100644 --- a/tests/retry_test.rs +++ b/tests/retry_test.rs @@ -34,10 +34,11 @@ async fn test_retry_strategy_none_no_retry(pool: PgPool) -> sqlx::Result<()> { FailingParams { error_message: "intentional failure".to_string(), }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::None), - max_attempts: Some(1), - ..Default::default() + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::None); + opts.max_attempts = Some(1); + opts }, ) .await @@ -87,12 +88,13 @@ async fn test_retry_strategy_fixed_delay(pool: PgPool) -> sqlx::Result<()> { FailingParams { error_message: "intentional failure".to_string(), }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(5), - }), - max_attempts: Some(2), - ..Default::default() + }); + opts.max_attempts = Some(2); + opts }, ) .await @@ -166,14 +168,15 @@ async fn test_retry_strategy_exponential_backoff(pool: PgPool) -> sqlx::Result<( FailingParams { error_message: "intentional failure".to_string(), }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Exponential { + { + let mut opts = 
SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Exponential { base_delay: Duration::from_secs(2), factor: 2.0, max_backoff: Duration::from_secs(100), - }), - max_attempts: Some(3), - ..Default::default() + }); + opts.max_attempts = Some(3); + opts }, ) .await @@ -249,12 +252,13 @@ async fn test_max_attempts_honored(pool: PgPool) -> sqlx::Result<()> { FailingParams { error_message: "intentional failure".to_string(), }, - SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { + { + let mut opts = SpawnOptions::default(); + opts.retry_strategy = Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0), - }), - max_attempts: Some(3), - ..Default::default() + }); + opts.max_attempts = Some(3); + opts }, ) .await diff --git a/tests/spawn_test.rs b/tests/spawn_test.rs index 247f840..41635ba 100644 --- a/tests/spawn_test.rs +++ b/tests/spawn_test.rs @@ -82,10 +82,8 @@ async fn test_spawn_with_custom_max_attempts(pool: PgPool) -> sqlx::Result<()> { client.create_queue(None).await.unwrap(); client.register::().await.unwrap(); - let options = SpawnOptions { - max_attempts: Some(10), - ..Default::default() - }; + let mut options = SpawnOptions::default(); + options.max_attempts = Some(10); let result = client .spawn_with_options::( @@ -109,10 +107,8 @@ async fn test_spawn_with_retry_strategy_none(pool: PgPool) -> sqlx::Result<()> { client.create_queue(None).await.unwrap(); client.register::().await.unwrap(); - let options = SpawnOptions { - retry_strategy: Some(RetryStrategy::None), - ..Default::default() - }; + let mut options = SpawnOptions::default(); + options.retry_strategy = Some(RetryStrategy::None); let result = client .spawn_with_options::( @@ -135,12 +131,10 @@ async fn test_spawn_with_retry_strategy_fixed(pool: PgPool) -> sqlx::Result<()> client.create_queue(None).await.unwrap(); client.register::().await.unwrap(); - let options = SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { - base_delay: Duration::from_secs(10), - }), - ..Default::default() - }; + let mut options = SpawnOptions::default(); + options.retry_strategy = Some(RetryStrategy::Fixed { + base_delay: Duration::from_secs(10), + }); let result = client .spawn_with_options::( @@ -163,14 +157,12 @@ async fn test_spawn_with_retry_strategy_exponential(pool: PgPool) -> sqlx::Resul client.create_queue(None).await.unwrap(); client.register::().await.unwrap(); - let options = SpawnOptions { - retry_strategy: Some(RetryStrategy::Exponential { - base_delay: Duration::from_secs(5), - factor: 2.0, - max_backoff: Duration::from_secs(300), - }), - ..Default::default() - }; + let mut options = SpawnOptions::default(); + options.retry_strategy = Some(RetryStrategy::Exponential { + base_delay: Duration::from_secs(5), + factor: 2.0, + max_backoff: Duration::from_secs(300), + }); let result = client .spawn_with_options::( @@ -197,10 +189,8 @@ async fn test_spawn_with_headers(pool: PgPool) -> sqlx::Result<()> { headers.insert("correlation_id".to_string(), serde_json::json!("abc-123")); headers.insert("priority".to_string(), serde_json::json!(5)); - let options = SpawnOptions { - headers: Some(headers), - ..Default::default() - }; + let mut options = SpawnOptions::default(); + options.headers = Some(headers); let result = client .spawn_with_options::( @@ -223,13 +213,11 @@ async fn test_spawn_with_cancellation_policy(pool: PgPool) -> sqlx::Result<()> { client.create_queue(None).await.unwrap(); client.register::().await.unwrap(); - let options = SpawnOptions { - cancellation: Some(CancellationPolicy { - 
max_pending_time: Some(Duration::from_secs(60)), - max_running_time: Some(Duration::from_secs(300)), - }), - ..Default::default() - }; + let mut options = SpawnOptions::default(); + options.cancellation = Some(CancellationPolicy { + max_pending_time: Some(Duration::from_secs(60)), + max_running_time: Some(Duration::from_secs(300)), + }); let result = client .spawn_with_options::( @@ -307,13 +295,11 @@ async fn test_spawn_by_name_with_options(pool: PgPool) -> sqlx::Result<()> { "message": "value" }); - let options = SpawnOptions { - max_attempts: Some(3), - retry_strategy: Some(RetryStrategy::Fixed { - base_delay: Duration::from_secs(5), - }), - ..Default::default() - }; + let mut options = SpawnOptions::default(); + options.max_attempts = Some(3); + options.retry_strategy = Some(RetryStrategy::Fixed { + base_delay: Duration::from_secs(5), + }); let result = client .spawn_by_name("echo", params, options) @@ -535,10 +521,8 @@ async fn test_spawn_rejects_reserved_header_prefix(pool: PgPool) -> sqlx::Result let mut headers = HashMap::new(); headers.insert("durable::custom".to_string(), serde_json::json!("value")); - let options = SpawnOptions { - headers: Some(headers), - ..Default::default() - }; + let mut options = SpawnOptions::default(); + options.headers = Some(headers); let result = client .spawn_with_options::( @@ -572,10 +556,8 @@ async fn test_spawn_allows_non_reserved_headers(pool: PgPool) -> sqlx::Result<() headers.insert("durable".to_string(), serde_json::json!("no colons")); headers.insert("durable:single".to_string(), serde_json::json!("one colon")); - let options = SpawnOptions { - headers: Some(headers), - ..Default::default() - }; + let mut options = SpawnOptions::default(); + options.headers = Some(headers); let result = client .spawn_with_options::( From e4d77834801835f224cbc912376dba3f59376b14 Mon Sep 17 00:00:00 2001 From: Aaron Hill Date: Thu, 22 Jan 2026 17:38:40 -0500 Subject: [PATCH 5/6] Fix review comments --- src/client.rs | 16 +--------------- src/error.rs | 2 +- tests/common/helpers.rs | 1 - 3 files changed, 2 insertions(+), 17 deletions(-) diff --git a/src/client.rs b/src/client.rs index a98f522..6d148ec 100644 --- a/src/client.rs +++ b/src/client.rs @@ -716,21 +716,7 @@ where }); } - // For now, we just manually construct a `Durable` with clones - // of our fields. In the future, we may want to make `Durable` - // implement `Clone`, or make an inner struct and wrap it in an `Arc` - Ok(Worker::start( - Durable { - pool: self.pool.clone(), - owns_pool: self.owns_pool, - queue_name: self.queue_name.clone(), - spawn_defaults: self.spawn_defaults.clone(), - registry: self.registry.clone(), - state: self.state.clone(), - }, - options, - ) - .await) + Ok(Worker::start(self.clone_inner(), options).await) } /// Close the client. Closes the pool if owned. 
diff --git a/src/error.rs b/src/error.rs
index c4d6cb2..7e97013 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -228,7 +228,7 @@ pub fn serialize_task_error(err: &TaskError) -> JsonValue {
             serde_json::json!({
                 "name": "SubtaskSpawnFailed",
                 "message": error.to_string(),
-                "name": name,
+                "subtask_name": name,
             })
         }
         TaskError::ChildFailed { step_name, message } => {
diff --git a/tests/common/helpers.rs b/tests/common/helpers.rs
index 57aa0b3..cfd5085 100644
--- a/tests/common/helpers.rs
+++ b/tests/common/helpers.rs
@@ -220,7 +220,6 @@ pub async fn wait_for_task_terminal(
     let poll_interval = Duration::from_millis(50);
 
     while start.elapsed() < timeout {
-        eprintln!("Task: {:?}", get_task_state(pool, queue, task_id).await);
         if let Some(state) = get_task_state(pool, queue, task_id).await?
             && (state == "completed" || state == "failed" || state == "cancelled")
         {

From e8e42f3cf978d21fa907f3274c113bbe86d9a350 Mon Sep 17 00:00:00 2001
From: Aaron Hill
Date: Fri, 23 Jan 2026 10:11:05 -0500
Subject: [PATCH 6/6] Mark owns_pool as false when client is cloned

---
 src/client.rs | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/client.rs b/src/client.rs
index 6d148ec..370093d 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -3,6 +3,7 @@ use serde_json::Value as JsonValue;
 use sqlx::{Executor, PgPool, Postgres};
 use std::collections::HashMap;
 use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, Ordering};
 use std::time::Duration;
 use tokio::sync::RwLock;
 use uuid::Uuid;
@@ -110,7 +111,7 @@ where
     State: Clone + Send + Sync + 'static,
 {
     pool: PgPool,
-    owns_pool: bool,
+    owns_pool: AtomicBool,
     queue_name: String,
     spawn_defaults: SpawnDefaults,
     registry: Arc>>,
@@ -123,9 +124,17 @@ impl Durable {
     /// TODO: Decide if we want to implement `Clone`,
     /// which will allow consumers to clone `Durable`
    /// Currently, we only allow cloning within the crate
     /// via this method
     pub(crate) fn clone_inner(&self) -> Durable {
+        // When we clone a durable client, mark *ourself* as no longer owning the pool
+        // This will cause `Durable.close()` to be a no-op, since something else could
+        // still be using the pool.
+        // sqlx itself will still close the pool when the last reference to it is dropped.
+        // At the moment, we only call `clone_inner` when spawning a worker, which has its own
+        // `shutdown()` method.
+        self.owns_pool.store(false, Ordering::Relaxed);
         Durable {
             pool: self.pool.clone(),
-            owns_pool: self.owns_pool,
+            // A clone of a durable client never owns the pool, so we set this to false
+            owns_pool: AtomicBool::new(false),
             queue_name: self.queue_name.clone(),
             spawn_defaults: self.spawn_defaults.clone(),
             registry: self.registry.clone(),
@@ -275,7 +284,7 @@ impl DurableBuilder {
 
         Ok(Durable {
             pool,
-            owns_pool,
+            owns_pool: AtomicBool::new(owns_pool),
             queue_name: self.queue_name,
             spawn_defaults: self.spawn_defaults,
             registry: Arc::new(RwLock::new(HashMap::new())),
@@ -721,7 +730,7 @@ where
 
     /// Close the client. Closes the pool if owned.
     pub async fn close(self) {
-        if self.owns_pool {
+        if self.owns_pool.load(Ordering::Relaxed) {
            self.pool.close().await;
        }
    }
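
Usage sketch (illustrative only, not taken from the patches above): with this series applied, a parent task spawns children through TaskContext::spawn, which now delegates to Durable::spawn_by_name and stamps parent_task_id, so the child is recorded under the parent's OpenTelemetry trace. The ChildTask type, the "child" step name, the import paths, and the exact spawn signature are assumptions for illustration; the option-building style mirrors the updated tests, which mutate SpawnOptions::default() because the struct now carries the private parent_task_id field.

use durable::{SpawnOptions, TaskContext, TaskHandle, TaskResult};

// Hypothetical parent-task body (State = (), as in the unit test above).
// ChildTask is assumed to be a registered Task implementation.
async fn parent_body(ctx: &mut TaskContext<()>) -> TaskResult<()> {
    // Build options by mutating a default value; struct-literal updates with
    // ..Default::default() are no longer possible outside the crate now that
    // SpawnOptions has a private parent_task_id field.
    let mut opts = SpawnOptions::default();
    opts.max_attempts = Some(3);

    // The spawn is checkpointed: if the parent retries, the cached child task id
    // is returned instead of spawning a second child. Internally this now calls
    // Durable::spawn_by_name with parent_task_id set to this task's id.
    let handle: TaskHandle<()> = ctx.spawn::<ChildTask>("child", (), opts).await?;

    // join() suspends the parent until the child reaches a terminal state and
    // surfaces ChildFailed / ChildCancelled errors from the child run.
    ctx.join(handle).await?;
    Ok(())
}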