diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index 269f267..0dd16ab 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -84,7 +84,7 @@ public class Client { return } try await self.onComplete(completeResponse, self) - case .unknown: + default: try await self.error(.invalidType()) } } diff --git a/Sources/GraphQLWS/GraphQLWSError.swift b/Sources/GraphQLWS/GraphQLWSError.swift index b5fce64..8036e91 100644 --- a/Sources/GraphQLWS/GraphQLWSError.swift +++ b/Sources/GraphQLWS/GraphQLWSError.swift @@ -60,14 +60,14 @@ struct GraphQLWSError: Error { static func invalidRequestFormat(messageType: RequestMessageType) -> Self { return self.init( - "Request message doesn't match '\(messageType.rawValue)' JSON format", + "Request message doesn't match '\(messageType.type.rawValue)' JSON format", code: .invalidRequestFormat ) } static func invalidResponseFormat(messageType: ResponseMessageType) -> Self { return self.init( - "Response message doesn't match '\(messageType.rawValue)' JSON format", + "Response message doesn't match '\(messageType.type.rawValue)' JSON format", code: .invalidResponseFormat ) } diff --git a/Sources/GraphQLWS/Requests.swift b/Sources/GraphQLWS/Requests.swift index 21f5abd..f86d69b 100644 --- a/Sources/GraphQLWS/Requests.swift +++ b/Sources/GraphQLWS/Requests.swift @@ -1,59 +1,124 @@ import Foundation import GraphQL -/// We also require that an 'authToken' field is provided in the 'payload' during the connection -/// init message. For example: -/// ``` -/// { -/// "type": 'connection_init', -/// "payload": { -/// "authToken": "eyJhbGciOiJIUz..." -/// } -/// } -/// ``` - /// A general request. This object's type is used to triage to other, more specific request objects. -struct Request: Equatable, JsonEncodable { - let type: RequestMessageType +public struct Request: Equatable, JsonEncodable { + public let type: RequestMessageType } /// A websocket `connection_init` request from the client to the server public struct ConnectionInitRequest: Equatable, JsonEncodable { - var type = RequestMessageType.GQL_CONNECTION_INIT - let payload: InitPayload + public let type: RequestMessageType = .GQL_CONNECTION_INIT + public let payload: InitPayload + + public init(payload: InitPayload) { + self.payload = payload + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: Self.CodingKeys.self) + if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_INIT { + throw DecodingError.dataCorrupted(.init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(RequestMessageType.GQL_CONNECTION_INIT.type)`" + )) + } + payload = try container.decode(InitPayload.self, forKey: .payload) + } } /// A websocket `start` request from the client to the server -struct StartRequest: Equatable, JsonEncodable { - var type = RequestMessageType.GQL_START - let payload: GraphQLRequest - let id: String +public struct StartRequest: Equatable, JsonEncodable { + public let type: RequestMessageType = .GQL_START + public let payload: GraphQLRequest + public let id: String + + public init(payload: GraphQLRequest, id: String) { + self.payload = payload + self.id = id + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: Self.CodingKeys.self) + if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_START { + throw DecodingError.dataCorrupted(.init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(RequestMessageType.GQL_START.type)`" + )) + } + payload = try container.decode(GraphQLRequest.self, forKey: .payload) + id = try container.decode(String.self, forKey: .id) + } } /// A websocket `stop` request from the client to the server -struct StopRequest: Equatable, JsonEncodable { - var type = RequestMessageType.GQL_STOP - let id: String +public struct StopRequest: Equatable, JsonEncodable { + public let type: RequestMessageType = .GQL_STOP + public let id: String + + public init(id: String) { + self.id = id + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: Self.CodingKeys.self) + if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE { + throw DecodingError.dataCorrupted(.init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(RequestMessageType.GQL_STOP.type)`" + )) + } + id = try container.decode(String.self, forKey: .id) + } } /// A websocket `connection_terminate` request from the client to the server -struct ConnectionTerminateRequest: Equatable, JsonEncodable { - var type = RequestMessageType.GQL_CONNECTION_TERMINATE +public struct ConnectionTerminateRequest: Equatable, JsonEncodable { + public let type: RequestMessageType = .GQL_CONNECTION_TERMINATE + + public init() {} + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: Self.CodingKeys.self) + if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE { + throw DecodingError.dataCorrupted(.init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(RequestMessageType.GQL_CONNECTION_TERMINATE.type)`" + )) + } + } } /// The supported websocket request message types from the client to the server -enum RequestMessageType: String, Codable { - case GQL_CONNECTION_INIT = "connection_init" - case GQL_START = "start" - case GQL_STOP = "stop" - case GQL_CONNECTION_TERMINATE = "connection_terminate" - case unknown - - init(from decoder: Decoder) throws { - guard let value = try? decoder.singleValueContainer().decode(String.self) else { - self = .unknown - return - } - self = RequestMessageType(rawValue: value) ?? .unknown +public struct RequestMessageType: Equatable, Codable, Sendable { + // This is implemented as a struct with only public static properties, backed by an internal enum + // in order to grow the list of accepted response types in a non-breaking way. + + let type: RequestType + + init(type: RequestType) { + self.type = type + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + type = try container.decode(RequestType.self) + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + try container.encode(type) + } + + public static let GQL_CONNECTION_INIT: Self = .init(type: .GQL_CONNECTION_INIT) + public static let GQL_START: Self = .init(type: .GQL_START) + public static let GQL_STOP: Self = .init(type: .GQL_STOP) + public static let GQL_CONNECTION_TERMINATE: Self = .init(type: .GQL_CONNECTION_TERMINATE) + + enum RequestType: String, Codable { + case GQL_CONNECTION_INIT = "connection_init" + case GQL_START = "start" + case GQL_STOP = "stop" + case GQL_CONNECTION_TERMINATE = "connection_terminate" } } diff --git a/Sources/GraphQLWS/Responses.swift b/Sources/GraphQLWS/Responses.swift index 525fa17..f4e5511 100644 --- a/Sources/GraphQLWS/Responses.swift +++ b/Sources/GraphQLWS/Responses.swift @@ -3,69 +3,120 @@ import GraphQL /// A general response. This object's type is used to triage to other, more specific response objects. public struct Response: Equatable, JsonEncodable { - let type: ResponseMessageType + public let type: ResponseMessageType } /// A websocket `connection_ack` response from the server to the client public struct ConnectionAckResponse: Equatable, JsonEncodable { - let type: ResponseMessageType + public let type: ResponseMessageType = .GQL_CONNECTION_ACK public let payload: [String: Map]? - init(_ payload: [String: Map]? = nil) { - type = .GQL_CONNECTION_ACK + public init(payload: [String: Map]?) { self.payload = payload } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: Self.CodingKeys.self) + if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_CONNECTION_ACK { + throw DecodingError.dataCorrupted(.init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.GQL_CONNECTION_ACK.type)`" + )) + } + payload = try container.decodeIfPresent([String: Map].self, forKey: .payload) + } } /// A websocket `connection_error` response from the server to the client public struct ConnectionErrorResponse: Equatable, JsonEncodable { - let type: ResponseMessageType + public let type: ResponseMessageType = .GQL_CONNECTION_ERROR public let payload: [String: Map]? - init(_ payload: [String: Map]? = nil) { - type = .GQL_CONNECTION_ERROR + public init(payload: [String: Map]? = nil) { self.payload = payload } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: Self.CodingKeys.self) + if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_CONNECTION_ERROR { + throw DecodingError.dataCorrupted(.init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.GQL_CONNECTION_ERROR.type)`" + )) + } + payload = try container.decodeIfPresent([String: Map].self, forKey: .payload) + } } /// A websocket `ka` response from the server to the client public struct ConnectionKeepAliveResponse: Equatable, JsonEncodable { - let type: ResponseMessageType + public let type: ResponseMessageType = .GQL_CONNECTION_KEEP_ALIVE public let payload: [String: Map]? - init(_ payload: [String: Map]? = nil) { - type = .GQL_CONNECTION_KEEP_ALIVE + public init(payload: [String: Map]?) { self.payload = payload } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: Self.CodingKeys.self) + if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_CONNECTION_KEEP_ALIVE { + throw DecodingError.dataCorrupted(.init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.GQL_CONNECTION_KEEP_ALIVE.type)`" + )) + } + payload = try container.decodeIfPresent([String: Map].self, forKey: .payload) + } } /// A websocket `data` response from the server to the client public struct DataResponse: Equatable, JsonEncodable { - let type: ResponseMessageType + public let type: ResponseMessageType = .GQL_DATA public let payload: GraphQLResult? public let id: String - init(_ payload: GraphQLResult? = nil, id: String) { - type = .GQL_DATA + public init(payload: GraphQLResult?, id: String) { self.payload = payload self.id = id } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: Self.CodingKeys.self) + if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_DATA { + throw DecodingError.dataCorrupted(.init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.GQL_DATA.type)`" + )) + } + payload = try container.decodeIfPresent(GraphQLResult.self, forKey: .payload) + id = try container.decode(String.self, forKey: .id) + } } /// A websocket `complete` response from the server to the client public struct CompleteResponse: Equatable, JsonEncodable { - let type: ResponseMessageType + public let type: ResponseMessageType = .GQL_COMPLETE public let id: String - init(id: String) { - type = .GQL_COMPLETE + public init(id: String) { self.id = id } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: Self.CodingKeys.self) + if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_COMPLETE { + throw DecodingError.dataCorrupted(.init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.GQL_COMPLETE.type)`" + )) + } + id = try container.decode(String.self, forKey: .id) + } } /// A websocket `error` response from the server to the client public struct ErrorResponse: Equatable, JsonEncodable { - let type: ResponseMessageType + public let type: ResponseMessageType = .GQL_ERROR public let payload: [GraphQLError] public let id: String @@ -78,28 +129,58 @@ public struct ErrorResponse: Equatable, JsonEncodable { return GraphQLError(error) } } - type = .GQL_ERROR payload = graphQLErrors self.id = id } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: Self.CodingKeys.self) + if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_ERROR { + throw DecodingError.dataCorrupted(.init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.GQL_ERROR.type)`" + )) + } + payload = try container.decode([GraphQLError].self, forKey: .payload) + id = try container.decode(String.self, forKey: .id) + } } /// The supported websocket response message types from the server to the client -enum ResponseMessageType: String, Codable { - case GQL_CONNECTION_ACK = "connection_ack" - case GQL_CONNECTION_ERROR = "connection_error" - case GQL_CONNECTION_KEEP_ALIVE = "ka" - case GQL_DATA = "data" - case GQL_ERROR = "error" - case GQL_COMPLETE = "complete" - case unknown - - init(from decoder: Decoder) throws { - guard let value = try? decoder.singleValueContainer().decode(String.self) else { - self = .unknown - return - } - self = ResponseMessageType(rawValue: value) ?? .unknown +public struct ResponseMessageType: Equatable, Codable, Sendable { + // This is implemented as a struct with only public static properties, backed by an internal enum + // in order to grow the list of accepted response types in a non-breaking way. + + let type: ResponseType + + init(type: ResponseType) { + self.type = type + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + type = try container.decode(ResponseType.self) + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + try container.encode(type) + } + + public static let GQL_CONNECTION_ACK: Self = .init(type: .GQL_CONNECTION_ACK) + public static let GQL_CONNECTION_ERROR: Self = .init(type: .GQL_CONNECTION_ERROR) + public static let GQL_CONNECTION_KEEP_ALIVE: Self = .init(type: .GQL_CONNECTION_KEEP_ALIVE) + public static let GQL_DATA: Self = .init(type: .GQL_DATA) + public static let GQL_ERROR: Self = .init(type: .GQL_ERROR) + public static let GQL_COMPLETE: Self = .init(type: .GQL_COMPLETE) + + enum ResponseType: String, Codable { + case GQL_CONNECTION_ACK = "connection_ack" + case GQL_CONNECTION_ERROR = "connection_error" + case GQL_CONNECTION_KEEP_ALIVE = "ka" + case GQL_DATA = "data" + case GQL_ERROR = "error" + case GQL_COMPLETE = "complete" } } diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index 24e5590..b6bd98f 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -95,7 +95,7 @@ public class Server< return } try await self.onConnectionTerminate(connectionTerminateRequest, messenger) - case .unknown: + default: try await self.error(.invalidType()) } } @@ -229,7 +229,7 @@ public class Server< private func sendConnectionAck(_ payload: [String: Map]? = nil) async throws { guard let messenger = messenger else { return } try await messenger.send( - ConnectionAckResponse(payload).toJSON(encoder) + ConnectionAckResponse(payload: payload).toJSON(encoder) ) } @@ -237,7 +237,7 @@ public class Server< private func sendConnectionError(_ payload: [String: Map]? = nil) async throws { guard let messenger = messenger else { return } try await messenger.send( - ConnectionErrorResponse(payload).toJSON(encoder) + ConnectionErrorResponse(payload: payload).toJSON(encoder) ) } @@ -245,7 +245,7 @@ public class Server< private func sendConnectionKeepAlive(_ payload: [String: Map]? = nil) async throws { guard let messenger = messenger else { return } try await messenger.send( - ConnectionKeepAliveResponse(payload).toJSON(encoder) + ConnectionKeepAliveResponse(payload: payload).toJSON(encoder) ) } @@ -254,7 +254,7 @@ public class Server< guard let messenger = messenger else { return } try await messenger.send( DataResponse( - payload, + payload: payload, id: id ).toJSON(encoder) ) diff --git a/Tests/GraphQLWSTests/GraphQLWSTests.swift b/Tests/GraphQLWSTests/GraphQLWSTests.swift index 0ab46da..37457b7 100644 --- a/Tests/GraphQLWSTests/GraphQLWSTests.swift +++ b/Tests/GraphQLWSTests/GraphQLWSTests.swift @@ -3,7 +3,7 @@ import Foundation import GraphQL import XCTest -@testable import GraphQLWS +import GraphQLWS class GraphqlWsTests: XCTestCase { var clientMessenger: TestMessenger!