Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use serde::Serialize;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use sqlx::{Executor, PgPool, Postgres};
use std::collections::HashMap;
Expand Down Expand Up @@ -684,7 +684,18 @@ where
#[cfg(feature = "telemetry")]
tracing::Span::current().record("queue", queue);

let payload_json = serde_json::to_value(payload)?;
let inner_payload_json = serde_json::to_value(payload)?;

#[allow(unused_mut)] // mut is needed when telemetry feature is enabled
let mut payload_wrapper = EventPayloadWrapper {
inner: inner_payload_json,
trace_context: HashMap::new(),
};

#[cfg(feature = "telemetry")]
crate::telemetry::inject_trace_context(&mut payload_wrapper.trace_context);

let payload_json = serde_json::to_value(payload_wrapper)?;

let query = "SELECT durable.emit_event($1, $2, $3)";
sqlx::query(query)
Expand Down Expand Up @@ -735,3 +746,13 @@ where
}
}
}

/// A wrapper struct that we use in `emit_event`.
///
/// This allows us to attach extra data (e.g. a trace context) alongside the
/// caller's payload when it is serialized and sent through `durable.emit_event`.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct EventPayloadWrapper {
    /// The caller-supplied event payload, already converted to JSON.
    pub inner: JsonValue,
    /// Trace-propagation entries. Populated by `inject_trace_context` when the
    /// `telemetry` feature is enabled; otherwise left empty.
    pub trace_context: HashMap<String, JsonValue>,
}
54 changes: 34 additions & 20 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::time::Duration;
use uuid::Uuid;

use crate::Durable;
use crate::client::EventPayloadWrapper;
use crate::error::{ControlFlow, TaskError, TaskResult};
use crate::task::Task;
use crate::types::{
Expand Down Expand Up @@ -351,7 +352,7 @@ where

// Check cache for already-received event
if let Some(cached) = self.checkpoint_cache.get(&checkpoint_name) {
return Ok(serde_json::from_value(cached.clone())?);
return Self::process_event_payload_wrapper(cached.clone());
}

// Check if we were woken by this event but it timed out (null payload)
Expand Down Expand Up @@ -383,10 +384,30 @@ where
}

// Event arrived - cache and return
let payload = result.payload.unwrap_or(JsonValue::Null);
let payload_wrapper_json = result.payload.unwrap_or(JsonValue::Null);
self.checkpoint_cache
.insert(checkpoint_name, payload.clone());
Ok(serde_json::from_value(payload)?)
.insert(checkpoint_name, payload_wrapper_json.clone());

Self::process_event_payload_wrapper(payload_wrapper_json)
}

/// Decode an [`EventPayloadWrapper`] JSON value and deserialize its inner
/// payload as `T`.
///
/// When the `telemetry` feature is enabled, the trace context carried in the
/// wrapper is extracted first and linked to the current tracing span, so the
/// receiving side of the event is connected to the emitting trace.
fn process_event_payload_wrapper<T: DeserializeOwned>(value: JsonValue) -> TaskResult<T> {
    let wrapper: EventPayloadWrapper = serde_json::from_value(value)?;

    #[cfg(feature = "telemetry")]
    {
        use opentelemetry::KeyValue;
        use opentelemetry::trace::TraceContextExt;
        use tracing_opentelemetry::OpenTelemetrySpanExt;

        // Attach the emitter's trace as a span link on the current span.
        let remote = crate::telemetry::extract_trace_context(&wrapper.trace_context);
        tracing::Span::current().add_link_with_attributes(
            remote.span().span_context().clone(),
            vec![KeyValue::new("sentry.link.type", "previous_trace")],
        );
    }

    let decoded = serde_json::from_value(wrapper.inner)?;
    Ok(decoded)
}

/// Emit an event to this task's queue.
Expand All @@ -404,22 +425,13 @@ where
)
)]
pub async fn emit_event<T: Serialize>(&self, event_name: &str, payload: &T) -> TaskResult<()> {
if event_name.is_empty() {
return Err(TaskError::Validation {
message: "event_name must be non-empty".to_string(),
});
}

let payload_json = serde_json::to_value(payload)?;
let query = "SELECT durable.emit_event($1, $2, $3)";
sqlx::query(query)
.bind(self.durable.queue_name())
.bind(event_name)
.bind(&payload_json)
.execute(self.durable.pool())
.await?;

Ok(())
self.durable
.emit_event(event_name, payload, None)
.await
.map_err(|e| TaskError::EmitEventFailed {
event_name: event_name.to_string(),
error: e,
})
}

/// Extend the task's lease to prevent timeout.
Expand Down Expand Up @@ -709,6 +721,8 @@ where
let query = "SELECT should_suspend, payload
FROM durable.await_event($1, $2, $3, $4, $5, $6)";

// This deliberately does *not* use our `await_event` wrapper, since this event is emitted
// by the durable sql itself (and does not use our `EventPayloadWrapper`)
let result: AwaitEventResult = sqlx::query_as(query)
.bind(self.durable.queue_name())
.bind(self.task_id)
Expand Down
14 changes: 14 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ pub enum TaskError {
#[error("failed to spawn subtask `{name}`: {error}")]
SubtaskSpawnFailed { name: String, error: DurableError },

/// Error occurred while trying to emit an event.
#[error("failed to emit event `{event_name}`: {error}")]
EmitEventFailed {
event_name: String,
error: DurableError,
},

/// A child task failed.
///
/// Returned by [`TaskContext::join`](crate::TaskContext::join) when the child
Expand Down Expand Up @@ -231,6 +238,13 @@ pub fn serialize_task_error(err: &TaskError) -> JsonValue {
"subtask_name": name,
})
}
TaskError::EmitEventFailed { event_name, error } => {
serde_json::json!({
"name": "EmitEventFailed",
"message": error.to_string(),
"event_name": event_name,
})
}
TaskError::ChildFailed { step_name, message } => {
serde_json::json!({
"name": "ChildFailed",
Expand Down
8 changes: 4 additions & 4 deletions tests/lock_order_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ async fn test_emit_event_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> {
"Task should be sleeping waiting for event"
);

// Emit the event
// Emit the event - payload must use EventPayloadWrapper format
let emit_query = AssertSqlSafe(
"SELECT durable.emit_event('lock_emit', 'test_event', '\"hello\"'::jsonb)".to_string(),
"SELECT durable.emit_event('lock_emit', 'test_event', '{\"inner\": \"hello\", \"trace_context\": {}}'::jsonb)".to_string(),
);
sqlx::query(emit_query).execute(&pool).await?;

Expand Down Expand Up @@ -328,13 +328,13 @@ async fn test_concurrent_emit_and_cancel(pool: PgPool) -> sqlx::Result<()> {
);
}

// Cancel one task while emitting the event
// Cancel one task while emitting the event - payload must use EventPayloadWrapper format
let cancel_task_id = task_ids[0];
let emit_handle = tokio::spawn({
let test_pool = test_pool.clone();
async move {
let emit_query = AssertSqlSafe(
"SELECT durable.emit_event('lock_emit_cancel', 'shared_event', '\"wakeup\"'::jsonb)"
"SELECT durable.emit_event('lock_emit_cancel', 'shared_event', '{\"inner\": \"wakeup\", \"trace_context\": {}}'::jsonb)"
.to_string(),
);
sqlx::query(emit_query).execute(&test_pool).await
Expand Down