diff --git a/src/client.rs b/src/client.rs index 370093d..4984d6f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,4 @@ -use serde::Serialize; +use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; use sqlx::{Executor, PgPool, Postgres}; use std::collections::HashMap; @@ -684,7 +684,18 @@ where #[cfg(feature = "telemetry")] tracing::Span::current().record("queue", queue); - let payload_json = serde_json::to_value(payload)?; + let inner_payload_json = serde_json::to_value(payload)?; + + #[allow(unused_mut)] // mut is needed when telemetry feature is enabled + let mut payload_wrapper = EventPayloadWrapper { + inner: inner_payload_json, + trace_context: HashMap::new(), + }; + + #[cfg(feature = "telemetry")] + crate::telemetry::inject_trace_context(&mut payload_wrapper.trace_context); + + let payload_json = serde_json::to_value(payload_wrapper)?; let query = "SELECT durable.emit_event($1, $2, $3)"; sqlx::query(query) @@ -735,3 +746,13 @@ where } } } + +/// A wrapper struct that we use in 'emit_event' +/// This allows us to attach extra data (e.g. a trace context) +#[derive(Serialize, Deserialize)] + +pub(crate) struct EventPayloadWrapper { + pub inner: JsonValue, + // Populated by 'inject_trace_context' + pub trace_context: HashMap, +} diff --git a/src/context.rs b/src/context.rs index 10993fe..1e776bb 100644 --- a/src/context.rs +++ b/src/context.rs @@ -6,6 +6,7 @@ use std::time::Duration; use uuid::Uuid; use crate::Durable; +use crate::client::EventPayloadWrapper; use crate::error::{ControlFlow, TaskError, TaskResult}; use crate::task::Task; use crate::types::{ @@ -351,7 +352,7 @@ where // Check cache for already-received event if let Some(cached) = self.checkpoint_cache.get(&checkpoint_name) { - return Ok(serde_json::from_value(cached.clone())?); + return Self::process_event_payload_wrapper(cached.clone()); } // Check if we were woken by this event but it timed out (null payload) @@ -383,10 +384,30 @@ where } // Event arrived - cache and return - let payload = result.payload.unwrap_or(JsonValue::Null); + let payload_wrapper_json = result.payload.unwrap_or(JsonValue::Null); self.checkpoint_cache - .insert(checkpoint_name, payload.clone()); - Ok(serde_json::from_value(payload)?) + .insert(checkpoint_name, payload_wrapper_json.clone()); + + Self::process_event_payload_wrapper(payload_wrapper_json) + } + + fn process_event_payload_wrapper(value: JsonValue) -> TaskResult { + let payload_wrapper: EventPayloadWrapper = serde_json::from_value(value)?; + let payload_inner = payload_wrapper.inner; + #[cfg(feature = "telemetry")] + { + use opentelemetry::KeyValue; + use opentelemetry::trace::TraceContextExt; + use tracing_opentelemetry::OpenTelemetrySpanExt; + + let context = crate::telemetry::extract_trace_context(&payload_wrapper.trace_context); + tracing::Span::current().add_link_with_attributes( + context.span().span_context().clone(), + vec![KeyValue::new("sentry.link.type", "previous_trace")], + ); + } + + Ok(serde_json::from_value(payload_inner)?) } /// Emit an event to this task's queue. @@ -404,22 +425,13 @@ where ) )] pub async fn emit_event(&self, event_name: &str, payload: &T) -> TaskResult<()> { - if event_name.is_empty() { - return Err(TaskError::Validation { - message: "event_name must be non-empty".to_string(), - }); - } - - let payload_json = serde_json::to_value(payload)?; - let query = "SELECT durable.emit_event($1, $2, $3)"; - sqlx::query(query) - .bind(self.durable.queue_name()) - .bind(event_name) - .bind(&payload_json) - .execute(self.durable.pool()) - .await?; - - Ok(()) + self.durable + .emit_event(event_name, payload, None) + .await + .map_err(|e| TaskError::EmitEventFailed { + event_name: event_name.to_string(), + error: e, + }) } /// Extend the task's lease to prevent timeout. @@ -709,6 +721,8 @@ where let query = "SELECT should_suspend, payload FROM durable.await_event($1, $2, $3, $4, $5, $6)"; + // This deliberately does *not* use our `await_event` wrapper, since this event is emitted + // by the durable sql itself (and does not use our `EventPayloadWrapper`) let result: AwaitEventResult = sqlx::query_as(query) .bind(self.durable.queue_name()) .bind(self.task_id) diff --git a/src/error.rs b/src/error.rs index 7e97013..baeaff6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -77,6 +77,13 @@ pub enum TaskError { #[error("failed to spawn subtask `{name}`: {error}")] SubtaskSpawnFailed { name: String, error: DurableError }, + /// Error occurred while trying to emit an event. + #[error("failed to emit event `{event_name}`: {error}")] + EmitEventFailed { + event_name: String, + error: DurableError, + }, + /// A child task failed. /// /// Returned by [`TaskContext::join`](crate::TaskContext::join) when the child @@ -231,6 +238,13 @@ pub fn serialize_task_error(err: &TaskError) -> JsonValue { "subtask_name": name, }) } + TaskError::EmitEventFailed { event_name, error } => { + serde_json::json!({ + "name": "EmitEventFailed", + "message": error.to_string(), + "event_name": event_name, + }) + } TaskError::ChildFailed { step_name, message } => { serde_json::json!({ "name": "ChildFailed", diff --git a/tests/lock_order_test.rs b/tests/lock_order_test.rs index be8cbb6..f541862 100644 --- a/tests/lock_order_test.rs +++ b/tests/lock_order_test.rs @@ -258,9 +258,9 @@ async fn test_emit_event_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { "Task should be sleeping waiting for event" ); - // Emit the event + // Emit the event - payload must use EventPayloadWrapper format let emit_query = AssertSqlSafe( - "SELECT durable.emit_event('lock_emit', 'test_event', '\"hello\"'::jsonb)".to_string(), + "SELECT durable.emit_event('lock_emit', 'test_event', '{\"inner\": \"hello\", \"trace_context\": {}}'::jsonb)".to_string(), ); sqlx::query(emit_query).execute(&pool).await?; @@ -328,13 +328,13 @@ async fn test_concurrent_emit_and_cancel(pool: PgPool) -> sqlx::Result<()> { ); } - // Cancel one task while emitting the event + // Cancel one task while emitting the event - payload must use EventPayloadWrapper format let cancel_task_id = task_ids[0]; let emit_handle = tokio::spawn({ let test_pool = test_pool.clone(); async move { let emit_query = AssertSqlSafe( - "SELECT durable.emit_event('lock_emit_cancel', 'shared_event', '\"wakeup\"'::jsonb)" + "SELECT durable.emit_event('lock_emit_cancel', 'shared_event', '{\"inner\": \"wakeup\", \"trace_context\": {}}'::jsonb)" .to_string(), ); sqlx::query(emit_query).execute(&test_pool).await