(_ identifier: WSEventIdentifier, _ handler: @escaping (WSClient, P?) -> Void) {
- binds[identifier.uid] = { [weak self] client, data in
- do {
- let res = try JSONDecoder().decode(WSEvent
.self, from: data)
- handler(client, res.payload)
- } catch {
- self?.logger?.log(.error(String(describing: error)), on: client.req)
- }
- }
- }
-
- public func bind(_ identifier: WSEventIdentifier, _ handler: @escaping (WSClient, P) -> Void) {
- binds[identifier.uid] = { [weak self] client, data in
- do {
- let res = try JSONDecoder().decode(WSEvent
.self, from: data)
- guard let payload = res.payload else { throw WSError(reason: "Unable to unwrap payload") }
- handler(client, payload)
- } catch {
- self?.logger?.log(.error(String(describing: error)), on: client.req)
- }
- }
- }
-
- /// Calls when a new client connects. Override this function to handle `onOpen`.
- open func onOpen(_ client: WSClient) {}
-
- /// Calls when a client disconnects. Override this function to handle `onClose`.
- open func onClose(_ client: WSClient) {}
-
- public override func wsOnOpen(_ ws: WS, _ client: WSClient) -> Bool {
- let result = super.wsOnOpen(ws, client)
- if result {
- onOpen(client)
- }
- return result
- }
-
- public override func wsOnClose(_ ws: WS, _ client: WSClient) {
- super.wsOnClose(ws, client)
- onClose(client)
- }
-
- public override func wsOnText(_ ws: WS, _ client: WSClient, _ text: String) {
- super.wsOnText(ws, client, text)
- if let data = text.data(using: .utf8) {
- proceedData(ws, client, data: data)
- }
- }
-
- public override func wsOnBinary(_ ws: WS, _ client: WSClient, _ data: Data) {
- super.wsOnBinary(ws, client, data)
- proceedData(ws, client, data: data)
- }
-
- func proceedData(_ ws: WS, _ client: WSClient, data: Data) {
- do {
- let prototype = try JSONDecoder().decode(WSEventPrototype.self, from: data)
- switch prototype.event {
- case "join": ws.joining(client, data: data, on: client.req)
- case "leave": ws.leaving(client, data: data, on: client.req)
- default: break
- }
- if let bind = binds.first(where: { $0.0 == prototype.event }) {
- bind.value(client, data)
- }
- } catch {
- logger?.log(.error(String(describing: error)), on: client.req)
- }
- }
-}
diff --git a/Sources/WS/Controllers/WSPureController.swift b/Sources/WS/Controllers/WSPureController.swift
deleted file mode 100644
index d35f0ec..0000000
--- a/Sources/WS/Controllers/WSPureController.swift
+++ /dev/null
@@ -1,43 +0,0 @@
-import Foundation
-
-open class WSPureController: WSObserver {
- public typealias OnOpenHandler = (WSClient) -> Void
- public var onOpen: OnOpenHandler?
-
- public typealias OnCloseHandler = () -> Void
- public var onClose: OnCloseHandler?
-
- public typealias OnTextHandler = (WSClient, String) -> Void
- public var onText: OnTextHandler?
-
- public typealias OnBinaryHandler = (WSClient, Data) -> Void
- public var onBinary: OnBinaryHandler?
-
- public typealias OnErrorHandler = (WSClient, Error) -> Void
- public var onError: OnErrorHandler?
-
- override public func wsOnOpen(_ ws: WS, _ client: WSClient) -> Bool {
- if super.wsOnOpen(ws, client) {
- onOpen?(client)
- return true
- }
- return false
- }
-
- override public func wsOnClose(_ ws: WS, _ client: WSClient) {
- super.wsOnClose(ws, client)
- onClose?()
- }
-
- override public func wsOnText(_ ws: WS, _ client: WSClient, _ text: String) {
- onText?(client, text)
- }
-
- override public func wsOnBinary(_ ws: WS, _ client: WSClient, _ data: Data) {
- onBinary?(client, data)
- }
-
- override public func wsOnError(_ ws: WS, _ client: WSClient, _ error: Error) {
- onError?(client, error)
- }
-}
diff --git a/Sources/WS/Enums/ExchangeMode.swift b/Sources/WS/Enums/ExchangeMode.swift
new file mode 100644
index 0000000..86c9396
--- /dev/null
+++ b/Sources/WS/Enums/ExchangeMode.swift
@@ -0,0 +1,12 @@
+public enum ExchangeMode {
+ /// all the messages will be sent and received as `text`
+ case text
+
+ /// all the messages will be sent and received as `binary data`
+ case binary
+
+ /// default mode
+ /// `binary data` will be sent and received as is
+ /// `text` will be sent and received as is
+ case both
+}
diff --git a/Sources/WS/Extensions/Application+Configurator.swift b/Sources/WS/Extensions/Application+Configurator.swift
new file mode 100644
index 0000000..0bcb753
--- /dev/null
+++ b/Sources/WS/Extensions/Application+Configurator.swift
@@ -0,0 +1,31 @@
+import Vapor
+
+extension Application {
+ /// Configure WS through this variable
+ ///
+ /// Declare WSID in extension
+ /// ```swift
+ /// extension WSID {
+ /// static var my: WSID { .init() }
+ /// }
+ /// ```
+ ///
+ /// Configure endpoint and start it serving
+ /// ```swift
+ /// app.ws.build(.my).at("ws").middlewares(...).serve()
+ /// app.ws.setDefault(.my)
+ /// ```
+ ///
+ /// Use it later on `Request`
+ /// ```swift
+ /// req.ws().send(...)
+ /// req.ws(.my).send(...)
+ /// ```
+ /// or `Application`
+ /// ```swift
+ /// app.ws.observer().send(...)
+ /// app.ws.observer(.my).send(...)
+ /// ```
+ ///
+ public var ws: Configurator { .init(self) }
+}
diff --git a/Sources/WS/Extensions/HTTPServerConfiguration+Address.swift b/Sources/WS/Extensions/HTTPServerConfiguration+Address.swift
new file mode 100644
index 0000000..e85f386
--- /dev/null
+++ b/Sources/WS/Extensions/HTTPServerConfiguration+Address.swift
@@ -0,0 +1,8 @@
+import Vapor
+
+extension HTTPServer.Configuration {
+ var address: String {
+ let scheme = tlsConfiguration == nil ? "http" : "https"
+ return "\(scheme)://\(hostname):\(port)"
+ }
+}
diff --git a/Sources/WS/Extensions/Request+Observer.swift b/Sources/WS/Extensions/Request+Observer.swift
new file mode 100644
index 0000000..1f23199
--- /dev/null
+++ b/Sources/WS/Extensions/Request+Observer.swift
@@ -0,0 +1,9 @@
+import Vapor
+
+extension Request {
+ /// Default websocket observer
+ public func ws() -> AnyObserver { application.ws.observer() }
+
+ /// Selected websocket observer
+ public func ws(_ wsid: WSID) -> Observer { application.ws.observer(wsid) }
+}
diff --git a/Sources/WS/Models/Configurator.swift b/Sources/WS/Models/Configurator.swift
new file mode 100644
index 0000000..41d0269
--- /dev/null
+++ b/Sources/WS/Models/Configurator.swift
@@ -0,0 +1,98 @@
+import Vapor
+
+public struct Configurator {
+ let application: Application
+
+ init (_ application: Application) {
+ self.application = application
+ }
+
+ // MARK: - Build
+
+ /// Websocket endpoint builder.
+ /// Don't forget to call `.serve()` in the end.
+ public func build(_ wsid: WSID) -> EndpointBuilder {
+ .init(application, wsid)
+ }
+
+ // MARK: - Observer
+
+ /// Returns default observer.
+ /// Works only after `.build()`, otherwise fatal error.
+ public func observer() -> AnyObserver {
+ var anywsid: AnyWSID? = application.ws.default
+ if anywsid == nil, let key = application.wsStorage.items.values.first?.key {
+ anywsid = _WSID(key: key)
+ application.logger.warning("[⚡️] 🚩 Default websocket observer is nil. Use app.ws.setDefault(...). Used first available websocket.")
+ }
+ guard let wsid = anywsid else {
+ fatalError("[⚡️] 🚩Default websocket observer is nil. Use app.ws.default(...)")
+ }
+ guard let observer = application.wsStorage[wsid.key] else {
+ fatalError("[⚡️] 🚩Unable to get websocket observer with key `\(wsid.key)`")
+ }
+ return observer
+ }
+
+ /// Returns observer for WSID.
+ /// Works only after `.build()`, otherwise fatal error.
+ public func observer(_ wsid: WSID) -> Observer {
+ guard let observer = application.wsStorage[wsid.key] as? Observer else {
+ fatalError("[⚡️] 🚩Websokcet with key `\(wsid.key)` is not running. Use app.ws.build(...).serve()")
+ }
+ return observer
+ }
+
+ // MARK: - Default WSID storage
+
+ /// Saves WSID as default.
+ /// After that you could call just `req.ws().send(...)` without providing WSID.
+ public func setDefault(_ wsid: WSID) {
+ self.default = wsid
+ }
+
+ struct DefaultWSIDKey: StorageKey {
+ typealias Value = AnyWSID
+ }
+
+ var `default`: AnyWSID? {
+ get {
+ application.storage[DefaultWSIDKey.self]
+ }
+ nonmutating set {
+ application.storage[DefaultWSIDKey.self] = newValue
+ }
+ }
+
+ // MARK: - Default Encoder
+
+ struct DefaultEncoderKey: StorageKey {
+ typealias Value = Encoder
+ }
+
+ /// Default encoder for all the observers, if `nil` then `JSONEncoder` is used.
+ public var encoder: Encoder? {
+ get {
+ application.storage[DefaultEncoderKey.self]
+ }
+ nonmutating set {
+ application.storage[DefaultEncoderKey.self] = newValue
+ }
+ }
+
+ // MARK: - Default Decoder
+
+ struct DefaultDecoderKey: StorageKey {
+ typealias Value = Decoder
+ }
+
+ /// Default encoder for all the observers, if `nil` then `JSONEncoder` is used.
+ public var decoder: Decoder? {
+ get {
+ application.storage[DefaultDecoderKey.self]
+ }
+ nonmutating set {
+ application.storage[DefaultDecoderKey.self] = newValue
+ }
+ }
+}
diff --git a/Sources/WS/Models/EID.swift b/Sources/WS/Models/EID.swift
new file mode 100644
index 0000000..ab97f58
--- /dev/null
+++ b/Sources/WS/Models/EID.swift
@@ -0,0 +1,29 @@
+import Foundation
+
+/// Event identifier model
+///
+/// Extend it to declare your websocket events
+/// ```swift
+/// extension EID {
+/// static var userOnline: EID { .init("userOnline") }
+/// }
+/// ```
+public struct EID: Equatable, Hashable, CustomStringConvertible, ExpressibleByStringLiteral {
+ /// The unique id.
+ public let id: String
+
+ /// See `CustomStringConvertible`.
+ public var description: String {
+ return id
+ }
+
+ /// Create a new `EventIdentifier`.
+ public init(_ id: String) {
+ self.id = id
+ }
+
+ /// See `ExpressibleByStringLiteral`.
+ public init(stringLiteral value: String) {
+ self.init(value)
+ }
+}
diff --git a/Sources/WS/Models/Event.swift b/Sources/WS/Models/Event.swift
new file mode 100644
index 0000000..4a7fafb
--- /dev/null
+++ b/Sources/WS/Models/Event.swift
@@ -0,0 +1,14 @@
+import Foundation
+
+struct Event: Codable {
+ public let event: String
+ public let payload: P?
+ public init (event: String, payload: P? = nil) {
+ self.event = event
+ self.payload = payload
+ }
+}
+
+struct EventPrototype: Codable {
+ public var event: String
+}
diff --git a/Sources/WS/Models/Nothing.swift b/Sources/WS/Models/Nothing.swift
new file mode 100644
index 0000000..176b195
--- /dev/null
+++ b/Sources/WS/Models/Nothing.swift
@@ -0,0 +1,2 @@
+/// Dummy model for EIDs without payload
+public struct Nothing: Codable {}
diff --git a/Sources/WS/Models/OriginalRequest.swift b/Sources/WS/Models/OriginalRequest.swift
new file mode 100644
index 0000000..4cf2ab8
--- /dev/null
+++ b/Sources/WS/Models/OriginalRequest.swift
@@ -0,0 +1,71 @@
+//import Vapor
+//import NIO
+//
+///// Represent an original HTTP request of WebSocket client connection
+//public struct OriginalRequest: CustomStringConvertible {
+// /// The HTTP method for this request.
+// ///
+// /// httpReq.method = .GET
+// ///
+// public let method: HTTPMethod
+//
+// /// The URL used on this request.
+// public let url: URI
+//
+// /// The version for this HTTP request.
+// public let version: HTTPVersion
+//
+// /// The header fields for this HTTP request.
+// /// The `"Content-Length"` and `"Transfer-Encoding"` headers will be set automatically
+// /// when the `body` property is mutated.
+// public let headers: HTTPHeaders
+//
+// // MARK: Metadata
+//
+// /// Route object we found for this request.
+// /// This holds metadata that can be used for (for example) Metrics.
+// ///
+// /// req.route?.description // "GET /hello/:name"
+// ///
+// public let route: Route?
+//
+// // MARK: Content
+//
+// public let query: URLQueryContainer
+//
+// public let content: ContentContainer
+//
+// public let body: Request.Body
+//
+// /// Get and set `HTTPCookies` for this `HTTPRequest`
+// /// This accesses the `"Cookie"` header.
+// public let cookies: HTTPCookies
+//
+// /// See `CustomStringConvertible`
+// public let description: String
+//
+// public let remoteAddress: SocketAddress?
+//
+// public let eventLoop: EventLoop
+//
+// public let parameters: Parameters
+//
+// public let userInfo: [AnyHashable: Any]
+//
+// init(_ request: Request) {
+// method = request.method
+// url = request.url
+// version = request.version
+// headers = request.headers
+// route = request.route
+// query = request.query
+// content = request.content
+// body = request.body
+// cookies = request.cookies
+// description = request.description
+// remoteAddress = request.remoteAddress
+// eventLoop = request.eventLoop
+// parameters = request.parameters
+// userInfo = request.userInfo
+// }
+//}
diff --git a/Sources/WS/Models/WSID.swift b/Sources/WS/Models/WSID.swift
new file mode 100644
index 0000000..63d2c4e
--- /dev/null
+++ b/Sources/WS/Models/WSID.swift
@@ -0,0 +1,22 @@
+import Vapor
+
+public protocol AnyWSID {
+ var key: String { get }
+}
+
+struct _WSID: AnyWSID {
+ let key: String
+}
+
+public struct WSID: AnyWSID {
+ public let key: String
+
+ public init(_ key: String? = nil) {
+ self.key = key ?? String(describing: Observer.self)
+ }
+}
+
+/// Set WSIDs in your app exactly the same way
+extension WSID {
+ public static var `default`: WSID { .init("ws") }
+}
diff --git a/Sources/WS/Objects/Client.swift b/Sources/WS/Objects/Client.swift
new file mode 100644
index 0000000..8c50e5a
--- /dev/null
+++ b/Sources/WS/Objects/Client.swift
@@ -0,0 +1,69 @@
+import Foundation
+import Vapor
+import NIOWebSocket
+
+class Client: _AnyClient {
+ /// See `AnyClient`
+ public let id: UUID = .init()
+ public let originalRequest: Request
+ public let application: Application
+
+ /// See `Loggable`
+ public let logger: Logger
+
+ /// See `_Sendable`
+ let observer: AnyObserver
+ let _observer: _AnyObserver
+ let sockets: [WebSocketKit.WebSocket]
+
+ /// See `AnyClient`
+ public internal(set) var channels: Set = []
+
+ /// See `Subscribable`
+ var clients: [_AnyClient] { [self] }
+
+ init (_ observer: _AnyObserver, _ request: Vapor.Request, _ socket: WebSocketKit.WebSocket, logger: Logger) {
+ self.observer = observer
+ self._observer = observer
+ self.originalRequest = request
+ self.application = request.application
+ self.logger = logger
+ self.sockets = [socket]
+ }
+}
+
+/// See `Sendable`
+
+extension Client {
+ public func send(text: S) -> EventLoopFuture where S : Collection, S.Element == Character {
+ _send(text: text)
+ }
+
+ public func send(bytes: [UInt8]) -> EventLoopFuture {
+ _send(bytes: bytes)
+ }
+
+ public func send(data: Data) -> EventLoopFuture where Data : DataProtocol {
+ _send(data: data)
+ }
+
+ public func send(data: Data, opcode: WebSocketOpcode) -> EventLoopFuture where Data: DataProtocol {
+ _send(data: data, opcode: opcode)
+ }
+
+ public func send(model: C) -> EventLoopFuture where C: Encodable {
+ _send(model: model)
+ }
+
+ public func send(model: C, encoder: Encoder) -> EventLoopFuture where C: Encodable {
+ _send(model: model, encoder: encoder)
+ }
+
+ public func send(event: EID) -> EventLoopFuture {
+ _send(event: event, payload: nil)
+ }
+
+ public func send(event: EID, payload: T?) -> EventLoopFuture {
+ _send(event: event, payload: payload)
+ }
+}
diff --git a/Sources/WS/Observers/BaseObserver.swift b/Sources/WS/Observers/BaseObserver.swift
new file mode 100644
index 0000000..a252d29
--- /dev/null
+++ b/Sources/WS/Observers/BaseObserver.swift
@@ -0,0 +1,36 @@
+import Foundation
+import Vapor
+
+open class BaseObserver {
+ public let key: String
+ public let path: String
+ public let logger: Logger
+ public let application: Application
+ public let exchangeMode: ExchangeMode
+ public var encoder: Encoder?
+ public var decoder: Decoder?
+
+ public internal(set) var clients: [AnyClient] = []
+ var _clients: [_AnyClient] = []
+
+ public required init (app: Application, key: String, path: String, exchangeMode: ExchangeMode) {
+ self.application = app
+ self.logger = app.logger
+ self.key = key
+ self.path = path.count > 0 ? path : "/"
+ self.exchangeMode = exchangeMode
+ setup()
+ }
+
+ open func setup() {}
+
+ // MARK: see `AnyObserver`
+
+ open func on(open client: AnyClient) {}
+ open func on(close client: AnyClient) {}
+ open func on(ping client: AnyClient) {}
+ open func on(pong client: AnyClient) {}
+ open func on(text: String, client: AnyClient) {}
+ open func on(byteBuffer: ByteBuffer, client: AnyClient) {}
+ open func on(data: Data, client: AnyClient) {}
+}
diff --git a/Sources/WS/Observers/BindableObserver.swift b/Sources/WS/Observers/BindableObserver.swift
new file mode 100644
index 0000000..ecb2c28
--- /dev/null
+++ b/Sources/WS/Observers/BindableObserver.swift
@@ -0,0 +1,54 @@
+import Foundation
+import Vapor
+import NIOWebSocket
+
+open class BindableObserver: BaseObserver, Bindable, _Bindable {
+ var binds: [String : BindHandler] = [:]
+
+ public func bind(_ identifier: EID
, _ handler: @escaping (AnyClient) -> Void) where P: Codable {
+ _bind(identifier, handler)
+ }
+
+ public func bindOptional
(_ identifier: EID
, _ handler: @escaping (AnyClient, P?) -> Void) where P : Codable {
+ _bindOptional(identifier, handler)
+ }
+
+ public func bind
(_ identifier: EID
, _ handler: @escaping (AnyClient, P) -> Void) where P : Codable {
+ _bind(identifier, handler)
+ }
+}
+
+/// See `Sendable`
+extension BindableObserver {
+ public func send(text: S) -> EventLoopFuture where S : Collection, S.Element == Character {
+ _send(text: text)
+ }
+
+ public func send(bytes: [UInt8]) -> EventLoopFuture {
+ _send(bytes: bytes)
+ }
+
+ public func send(data: Data) -> EventLoopFuture where Data : DataProtocol {
+ _send(data: data)
+ }
+
+ public func send(data: Data, opcode: WebSocketOpcode) -> EventLoopFuture where Data: DataProtocol {
+ _send(data: data, opcode: opcode)
+ }
+
+ public func send(model: C) -> EventLoopFuture where C: Encodable {
+ _send(model: model)
+ }
+
+ public func send(model: C, encoder: Encoder) -> EventLoopFuture where C: Encodable {
+ _send(model: model, encoder: encoder)
+ }
+
+ public func send(event: EID) -> EventLoopFuture {
+ _send(event: event, payload: nil)
+ }
+
+ public func send(event: EID, payload: T?) -> EventLoopFuture {
+ _send(event: event, payload: payload)
+ }
+}
diff --git a/Sources/WS/Observers/ClassicObserver.swift b/Sources/WS/Observers/ClassicObserver.swift
new file mode 100644
index 0000000..dd450fb
--- /dev/null
+++ b/Sources/WS/Observers/ClassicObserver.swift
@@ -0,0 +1,41 @@
+import Foundation
+import Vapor
+import NIOWebSocket
+
+open class ClassicObserver: BaseObserver, _AnyObserver, AnyObserver {}
+
+/// See `Sendable`
+
+extension ClassicObserver {
+ public func send(text: S) -> EventLoopFuture where S : Collection, S.Element == Character {
+ _send(text: text)
+ }
+
+ public func send(bytes: [UInt8]) -> EventLoopFuture {
+ _send(bytes: bytes)
+ }
+
+ public func send(data: Data) -> EventLoopFuture where Data : DataProtocol {
+ _send(data: data)
+ }
+
+ public func send(data: Data, opcode: WebSocketOpcode) -> EventLoopFuture where Data: DataProtocol {
+ _send(data: data, opcode: opcode)
+ }
+
+ public func send(model: C) -> EventLoopFuture where C: Encodable {
+ _send(model: model)
+ }
+
+ public func send(model: C, encoder: Encoder) -> EventLoopFuture where C: Encodable {
+ _send(model: model, encoder: encoder)
+ }
+
+ public func send(event: EID) -> EventLoopFuture {
+ _send(event: event, payload: nil)
+ }
+
+ public func send(event: EID, payload: T?) -> EventLoopFuture {
+ _send(event: event, payload: payload)
+ }
+}
diff --git a/Sources/WS/Observers/DeclarativeObserver.swift b/Sources/WS/Observers/DeclarativeObserver.swift
new file mode 100644
index 0000000..9cb5cf7
--- /dev/null
+++ b/Sources/WS/Observers/DeclarativeObserver.swift
@@ -0,0 +1,43 @@
+import Foundation
+import Vapor
+import NIOWebSocket
+
+open class DeclarativeObserver: BaseObserver, _Declarativable, Declarativable {
+ public internal(set) var handlers: DeclarativeHandlers = .init()
+}
+
+/// See `Sendable`
+
+extension DeclarativeObserver {
+ public func send(text: S) -> EventLoopFuture where S : Collection, S.Element == Character {
+ _send(text: text)
+ }
+
+ public func send(bytes: [UInt8]) -> EventLoopFuture {
+ _send(bytes: bytes)
+ }
+
+ public func send(data: Data) -> EventLoopFuture where Data : DataProtocol {
+ _send(data: data)
+ }
+
+ public func send(data: Data, opcode: WebSocketOpcode) -> EventLoopFuture where Data: DataProtocol {
+ _send(data: data, opcode: opcode)
+ }
+
+ public func send(model: C) -> EventLoopFuture where C: Encodable {
+ _send(model: model)
+ }
+
+ public func send(model: C, encoder: Encoder) -> EventLoopFuture where C: Encodable {
+ _send(model: model, encoder: encoder)
+ }
+
+ public func send(event: EID) -> EventLoopFuture {
+ _send(event: event, payload: nil)
+ }
+
+ public func send(event: EID, payload: T?) -> EventLoopFuture {
+ _send(event: event, payload: payload)
+ }
+}
diff --git a/Sources/WS/Protocols/AnyClient.swift b/Sources/WS/Protocols/AnyClient.swift
new file mode 100644
index 0000000..4727957
--- /dev/null
+++ b/Sources/WS/Protocols/AnyClient.swift
@@ -0,0 +1,33 @@
+import Foundation
+import Vapor
+import NIOWebSocket
+
+public protocol AnyClient: Broadcastable, Disconnectable, Subscribable, Sendable {
+ var id: UUID { get }
+ var application: Application { get }
+ var eventLoop: EventLoop { get }
+ var originalRequest: Request { get }
+ var channels: Set { get }
+ var sockets: [WebSocket] { get }
+ var observer: AnyObserver { get }
+}
+
+internal protocol _AnyClient: AnyClient, _Disconnectable, _Subscribable, _Sendable {
+ var _observer: _AnyObserver { get }
+ var channels: Set { get set }
+}
+
+extension AnyClient {
+ public var eventLoop: EventLoop { application.eventLoopGroup.next() }
+ public var logger: Logger { application.logger }
+
+ /// See `Broadcastable`
+ public var broadcast: Broadcaster {
+ observer.broadcast
+ }
+}
+
+extension _AnyClient {
+ public var exchangeMode: ExchangeMode { observer.exchangeMode }
+ var _encoder: Encoder { observer._encoder }
+}
diff --git a/Sources/WS/Protocols/AnyObserver.swift b/Sources/WS/Protocols/AnyObserver.swift
new file mode 100644
index 0000000..29db098
--- /dev/null
+++ b/Sources/WS/Protocols/AnyObserver.swift
@@ -0,0 +1,167 @@
+import Vapor
+
+public protocol AnyObserver: class, Broadcastable, CustomStringConvertible, Disconnectable, Sendable, Loggable {
+ var key: String { get }
+ var path: String { get }
+
+ var application: Application { get }
+ var eventLoop: EventLoop { get }
+ var clients: [AnyClient] { get }
+ var encoder: Encoder? { get set }
+ var decoder: Decoder? { get set }
+ var exchangeMode: ExchangeMode { get }
+
+ init (app: Application, key: String, path: String, exchangeMode: ExchangeMode)
+
+ func setup()
+
+ func on(open client: AnyClient)
+ func on(close client: AnyClient)
+ func on(ping client: AnyClient)
+ func on(pong client: AnyClient)
+ func on(text: String, client: AnyClient)
+ func on(byteBuffer: ByteBuffer, client: AnyClient)
+ func on(data: Data, client: AnyClient)
+}
+
+internal protocol _AnyObserver: AnyObserver, _Disconnectable, _Sendable {
+ var _clients: [_AnyClient] { get set }
+ var _encoder: Encoder { get }
+ var _decoder: Decoder { get }
+
+ func _on(open client: _AnyClient)
+ func _on(close client: _AnyClient)
+ func _on(ping client: _AnyClient)
+ func _on(pong client: _AnyClient)
+ func _on(text: String, client: _AnyClient)
+ func _on(byteBuffer: ByteBuffer, client: _AnyClient)
+ func _on(data: Data, client: _AnyClient)
+}
+
+// MARK: - Default implementation
+
+extension AnyObserver {
+ public var eventLoop: EventLoop { application.eventLoopGroup.next() }
+
+ var _encoder: Encoder {
+ if let encoder = self.encoder {
+ return encoder
+ }
+ if let encoder = application.ws.encoder {
+ return encoder
+ }
+ let encoder = JSONEncoder()
+ encoder.dateEncodingStrategy = .formatted(DefaultDateFormatter())
+ return encoder
+ }
+
+ var _decoder: Decoder {
+ if let decoder = self.decoder {
+ return decoder
+ }
+ if let decoder = application.ws.decoder {
+ return decoder
+ }
+ let decoder = JSONDecoder()
+ decoder.dateDecodingStrategy = .formatted(DefaultDateFormatter())
+ return decoder
+ }
+
+ func handle(_ req: Request, _ ws: WebSocketKit.WebSocket) {
+ guard let self = self as? _AnyObserver else { return }
+ self.handle(req, ws)
+ }
+
+ /// See `Broadcastable`
+
+ public var broadcast: Broadcaster {
+ .init(eventLoop: eventLoop,
+ clients: clients,
+ exchangeMode: exchangeMode,
+ logger: application.logger,
+ encoder: encoder,
+ defaultEncoder: application.ws.encoder)
+ }
+
+ /// see `CustomStringConvertible`
+ public var description: String {
+ "\(String(describing: Self.self))(key: \"\(key)\", at: \"\(path)\")"
+ }
+}
+
+extension _AnyObserver {
+ var clients: [AnyClient] { _clients }
+ var observer: _AnyObserver { self }
+ var sockets: [WebSocket] { _clients.flatMap { $0.sockets } }
+
+ /// Internal handler
+
+ func handle(_ req: Request, _ ws: WebSocketKit.WebSocket) {
+ let client = Client(self, req, ws, logger: logger)
+ _clients.append(client)
+
+ _on(open: client)
+ on(open: client)
+ logger.info("[⚡️] 🟢 new connection \(client.id)")
+
+ _ = ws.onClose.map {
+ self.logger.info("[⚡️] 🔴 connection closed \(client.id)")
+ self._clients.removeAll(where: { $0 === client })
+ self._on(close: client)
+ self.on(close: client)
+ }
+
+ ws.onPing { _ in
+ self.logger.debug("[⚡️] 🏓 ping \(client.id)")
+ self._on(ping: client)
+ self.on(ping: client)
+ }
+
+ ws.onPong { _ in
+ self.logger.debug("[⚡️] 🏓 pong \(client.id)")
+ self._on(pong: client)
+ self.on(pong: client)
+ }
+
+ ws.onText { _, text in
+ guard self.exchangeMode != .binary else {
+ self.logger.warning("[⚡️] ❗️📤❗️incoming text event has been rejected. Observer is in `binary` mode.")
+ return
+ }
+ self.logger.debug("[⚡️] 📥 \(client.id) text: \(text)")
+ self._on(text: text, client: client)
+ self.on(text: text, client: client)
+ }
+
+ ws.onBinary { _, byteBuffer in
+ guard self.exchangeMode != .text else {
+ self.logger.warning("[⚡️] ❗️📤❗️incoming binary event has been rejected. Observer is in `text` mode.")
+ return
+ }
+ self.logger.debug("[⚡️] 📥 \(client.id) data: \(byteBuffer.readableBytes)")
+ self._on(byteBuffer: byteBuffer, client: client)
+ self.on(byteBuffer: byteBuffer, client: client)
+ guard byteBuffer.readableBytes > 0 else { return }
+ var bytes: [UInt8] = byteBuffer.getBytes(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes) ?? []
+ let data = Data(bytes: &bytes, count: byteBuffer.readableBytes)
+ self._on(data: data, client: client)
+ self.on(data: data, client: client)
+ }
+ }
+
+ public func on(open client: AnyClient) {}
+ public func on(close client: AnyClient) {}
+ public func on(ping client: AnyClient) {}
+ public func on(pong client: AnyClient) {}
+ public func on(text: String, client: AnyClient) {}
+ public func on(byteBuffer: ByteBuffer, client: AnyClient) {}
+ public func on(data: Data, client: AnyClient) {}
+
+ func _on(open client: _AnyClient) {}
+ func _on(close client: _AnyClient) {}
+ func _on(ping client: _AnyClient) {}
+ func _on(pong client: _AnyClient) {}
+ func _on(text: String, client: _AnyClient) {}
+ func _on(byteBuffer: ByteBuffer, client: _AnyClient) {}
+ func _on(data: Data, client: _AnyClient) {}
+}
diff --git a/Sources/WS/Protocols/Bindable.swift b/Sources/WS/Protocols/Bindable.swift
new file mode 100644
index 0000000..9e8e25c
--- /dev/null
+++ b/Sources/WS/Protocols/Bindable.swift
@@ -0,0 +1,109 @@
+import Foundation
+import NIO
+
+typealias BindHandler = (AnyClient, Data) -> Void
+
+public protocol Bindable: AnyObserver {
+ /// Binds to event without payload
+ ///
+ /// - parameters:
+ /// - identifier: `EID` event identifier, declare it in extension
+ /// - handler: called when event happens
+ func bind(_ identifier: EID
, _ handler: @escaping (AnyClient) -> Void) where P: Codable
+
+ /// Binds to event with optional payload
+ ///
+ /// - parameters:
+ /// - identifier: `EID` event identifier, declare it in extension
+ /// - handler: called when event happens
+ func bindOptional
(_ identifier: EID
, _ handler: @escaping (AnyClient, P?) -> Void) where P: Codable
+
+ /// Binds to event with required payload
+ ///
+ /// - parameters:
+ /// - identifier: `EID` event identifier, declare it in extension
+ /// - handler: called when event happens
+ func bind
(_ identifier: EID
, _ handler: @escaping (AnyClient, P) -> Void) where P: Codable
+}
+
+internal protocol _Bindable: Bindable, _AnyObserver {
+ var binds: [String: BindHandler] { get set }
+}
+
+extension _Bindable {
+ func _bind(_ identifier: EID, _ handler: @escaping (AnyClient) -> Void) {
+ bindOptional(identifier) { client, _ in
+ handler(client)
+ }
+ }
+
+ func _bindOptional(_ identifier: EID, _ handler: @escaping (AnyClient, P?) -> Void) {
+ binds[identifier.id] = { client, data in
+ do {
+ let res = try self._decoder.decode(Event
.self, from: data)
+ handler(client, res.payload)
+ } catch {
+ self.unableToDecode(identifier.id, error)
+ }
+ }
+ }
+
+ func _bind(_ identifier: EID, _ handler: @escaping (AnyClient, P) -> Void) {
+ binds[identifier.id] = { client, data in
+ do {
+ let res = try self._decoder.decode(Event
.self, from: data)
+ if let payload = res.payload {
+ handler(client, payload)
+ } else {
+ self.logger.warning("[⚡️] ❗️📥❗️Unable to unwrap payload for event `\(identifier.id)`, because it is unexpectedly nil. Please use another `bind` method which support optional payload to avoid this message.")
+ }
+ } catch {
+ self.unableToDecode(identifier.id, error)
+ }
+ }
+ }
+
+ private func unableToDecode(_ id: String, _ error: Error) {
+ switch logger.logLevel {
+ case .debug: logger.debug("[⚡️] ❗️📥❗️Undecodable incoming event `\(id)`: \(error)")
+ default: logger.error("[⚡️] ❗️📥❗️Unable to decode incoming event `\(id)`")
+ }
+ }
+}
+
+extension _Bindable {
+ func _on(text: String, client: _AnyClient) {
+ if let data = text.data(using: .utf8) {
+ proceedData(client, data)
+ }
+ }
+
+ func _on(byteBuffer: ByteBuffer, client: _AnyClient) {
+ guard byteBuffer.readableBytes > 0 else { return }
+ var bytes: [UInt8] = byteBuffer.getBytes(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes) ?? []
+ let data = Data(bytes: &bytes, count: byteBuffer.readableBytes)
+ proceedData(client, data)
+ }
+
+ func _on(data: Data, client: _AnyClient) {
+ proceedData(client, data)
+ }
+
+ private func proceedData(_ client: _AnyClient, _ data: Data) {
+ do {
+ let prototype = try _decoder.decode(EventPrototype.self, from: data)
+ if let bind = binds.first(where: { $0.0 == prototype.event }) {
+ bind.value(client, data)
+ }
+ } catch {
+ unableToDecode(error)
+ }
+ }
+
+ private func unableToDecode(_ error: Error) {
+ switch logger.logLevel {
+ case .debug: logger.debug("[⚡️] ❗️📥❗️Unable to decode incoming event cause it doesn't conform to `EventPrototype` model: \(error)")
+ default: logger.error("[⚡️] ❗️📥❗️Unable to decode incoming event cause it doesn't conform to `EventPrototype` model")
+ }
+ }
+}
diff --git a/Sources/WS/Protocols/Broadcastable.swift b/Sources/WS/Protocols/Broadcastable.swift
new file mode 100644
index 0000000..7b8889c
--- /dev/null
+++ b/Sources/WS/Protocols/Broadcastable.swift
@@ -0,0 +1,6 @@
+import Foundation
+import NIO
+
+public protocol Broadcastable {
+ var broadcast: Broadcaster { get }
+}
diff --git a/Sources/WS/Protocols/Declarativable.swift b/Sources/WS/Protocols/Declarativable.swift
new file mode 100644
index 0000000..b4cab4c
--- /dev/null
+++ b/Sources/WS/Protocols/Declarativable.swift
@@ -0,0 +1,153 @@
+import Vapor
+
+public typealias EmptyHandler = () -> Void
+public typealias OpenCloseHandler = (AnyClient) -> Void
+public typealias TextHandler = (AnyClient, String) -> Void
+public typealias ByteBufferHandler = (AnyClient, ByteBuffer) -> Void
+public typealias BinaryHandler = (AnyClient, Data) -> Void
+
+public class DeclarativeHandlers {
+ var openHander: OpenCloseHandler?
+ var closeHander: OpenCloseHandler?
+ var pingHander: OpenCloseHandler?
+ var pongHander: OpenCloseHandler?
+ var textHander: TextHandler?
+ var byteBufferHander: ByteBufferHandler?
+ var binaryHander: BinaryHandler?
+}
+
+public protocol Declarativable: AnyObserver {
+ var handlers: DeclarativeHandlers { get }
+}
+
+internal protocol _Declarativable: Declarativable, _AnyObserver {
+ var handlers: DeclarativeHandlers { get set }
+}
+
+extension _Declarativable {
+ func _on(open client: _AnyClient) {
+ handlers.openHander?(client)
+ }
+
+ func _on(close client: _AnyClient) {
+ handlers.closeHander?(client)
+ }
+
+ func _on(ping client: _AnyClient) {
+ handlers.pingHander?(client)
+ }
+
+ func _on(pong client: _AnyClient) {
+ handlers.pongHander?(client)
+ }
+
+ func _on(text: String, client: _AnyClient) {
+ guard let handler = handlers.textHander else {
+ logger.warning("[⚡️] ❗️📥❗️ \(description) received `text` but handler is nil")
+ return
+ }
+ handler(client, text)
+ }
+
+ func _on(byteBuffer: ByteBuffer, client: _AnyClient) {
+ guard let handler = handlers.byteBufferHander else {
+ logger.warning("[⚡️] ❗️📥❗️ \(description) received `byteBuffer` but handler is nil")
+ return
+ }
+ handler(client, byteBuffer)
+ }
+
+ func _on(data: Data, client: _AnyClient) {
+ guard let handler = handlers.binaryHander else {
+ logger.warning("[⚡️] ❗️📥❗️ \(description) received `binary data` but handler is nil")
+ return
+ }
+ handler(client, data)
+ }
+}
+
+extension Declarativable {
+ @discardableResult
+ public func onOpen(_ handler: @escaping OpenCloseHandler) -> Self {
+ handlers.openHander = handler
+ return self
+ }
+
+ @discardableResult
+ public func onOpen(_ handler: @escaping EmptyHandler) -> Self {
+ handlers.openHander = { _ in handler() }
+ return self
+ }
+
+ @discardableResult
+ public func onClose(_ handler: @escaping OpenCloseHandler) -> Self {
+ handlers.closeHander = handler
+ return self
+ }
+
+ @discardableResult
+ public func onClose(_ handler: @escaping EmptyHandler) -> Self {
+ handlers.closeHander = { _ in handler() }
+ return self
+ }
+
+ @discardableResult
+ public func onPing(_ handler: @escaping OpenCloseHandler) -> Self {
+ handlers.pingHander = handler
+ return self
+ }
+
+ @discardableResult
+ public func onPing(_ handler: @escaping EmptyHandler) -> Self {
+ handlers.pingHander = { _ in handler() }
+ return self
+ }
+
+ @discardableResult
+ public func onPong(_ handler: @escaping OpenCloseHandler) -> Self {
+ handlers.pongHander = handler
+ return self
+ }
+
+ @discardableResult
+ public func onPong(_ handler: @escaping EmptyHandler) -> Self {
+ handlers.pongHander = { _ in handler() }
+ return self
+ }
+
+ @discardableResult
+ public func onText(_ handler: @escaping TextHandler) -> Self {
+ handlers.textHander = handler
+ return self
+ }
+
+ @discardableResult
+ public func onText(_ handler: @escaping EmptyHandler) -> Self {
+ handlers.textHander = { _,_ in handler() }
+ return self
+ }
+
+ @discardableResult
+ public func onByteBuffer(_ handler: @escaping ByteBufferHandler) -> Self {
+ handlers.byteBufferHander = handler
+ return self
+ }
+
+ @discardableResult
+ public func onByteBuffer(_ handler: @escaping EmptyHandler) -> Self {
+ handlers.byteBufferHander = { _,_ in handler() }
+ return self
+ }
+
+ @discardableResult
+ public func onBinary(_ handler: @escaping BinaryHandler) -> Self {
+ handlers.binaryHander = handler
+ return self
+ }
+
+ @discardableResult
+ public func onBinary(_ handler: @escaping EmptyHandler) -> Self {
+ handlers.binaryHander = { _,_ in handler() }
+ return self
+ }
+}
diff --git a/Sources/WS/Protocols/Decoder.swift b/Sources/WS/Protocols/Decoder.swift
new file mode 100644
index 0000000..581adf9
--- /dev/null
+++ b/Sources/WS/Protocols/Decoder.swift
@@ -0,0 +1,7 @@
+import Foundation
+
+public protocol Decoder {
+ func decode(_ type: T.Type, from data: Data) throws -> T where T : Decodable
+}
+
+extension JSONDecoder: Decoder {}
diff --git a/Sources/WS/Protocols/Delegate.swift b/Sources/WS/Protocols/Delegate.swift
new file mode 100644
index 0000000..9911810
--- /dev/null
+++ b/Sources/WS/Protocols/Delegate.swift
@@ -0,0 +1,9 @@
+//import Foundation
+//
+//public protocol Delegate {
+// func wsOnOpen(_ ws: WS, _ client: Client) -> Bool
+// func wsOnClose(_ ws: WS, _ client: Client)
+// func wsOnText(_ ws: WS, _ client: Client, _ text: String)
+// func wsOnBinary(_ ws: WS, _ client: Client, _ data: Data)
+// func wsOnError(_ ws: WS, _ client: Client, _ error: Error)
+//}
diff --git a/Sources/WS/Protocols/Disconnectable.swift b/Sources/WS/Protocols/Disconnectable.swift
new file mode 100644
index 0000000..31fa672
--- /dev/null
+++ b/Sources/WS/Protocols/Disconnectable.swift
@@ -0,0 +1,49 @@
+import Vapor
+import NIOWebSocket
+
+public protocol Disconnectable {
+ @discardableResult
+ func disconnect() -> EventLoopFuture
+ @discardableResult
+ func disconnect(code: WebSocketErrorCode) -> EventLoopFuture
+}
+
+internal protocol _Disconnectable: Disconnectable {
+ var eventLoop: EventLoop { get }
+ var sockets: [WebSocket] { get }
+}
+
+extension _Disconnectable {
+ public func disconnect() -> EventLoopFuture {
+ _disconnect()
+ }
+
+ public func disconnect(code: WebSocketErrorCode) -> EventLoopFuture {
+ _disconnect(code: code)
+ }
+
+ func _disconnect() -> EventLoopFuture {
+ eventLoop.future().flatMap {
+ self._disconnect(code: .goingAway)
+ }
+ }
+
+ func _disconnect(code: WebSocketErrorCode) -> EventLoopFuture {
+ guard sockets.count > 0 else { return eventLoop.future() }
+ return sockets.map {
+ $0.close(code: code)
+ }.flatten(on: eventLoop)
+ }
+}
+
+// MARK: - EventLoopFuture
+
+extension EventLoopFuture: Disconnectable where Value: Disconnectable {
+ public func disconnect() -> EventLoopFuture {
+ flatMap { $0.disconnect() }
+ }
+
+ public func disconnect(code: WebSocketErrorCode) -> EventLoopFuture {
+ flatMap { $0.disconnect(code: code) }
+ }
+}
diff --git a/Sources/WS/Protocols/Encoder.swift b/Sources/WS/Protocols/Encoder.swift
new file mode 100644
index 0000000..a191583
--- /dev/null
+++ b/Sources/WS/Protocols/Encoder.swift
@@ -0,0 +1,7 @@
+import Foundation
+
+public protocol Encoder {
+ func encode(_ value: T) throws -> Data where T : Encodable
+}
+
+extension JSONEncoder: Encoder {}
diff --git a/Sources/WS/Protocols/Loggable.swift b/Sources/WS/Protocols/Loggable.swift
new file mode 100644
index 0000000..792d6c0
--- /dev/null
+++ b/Sources/WS/Protocols/Loggable.swift
@@ -0,0 +1,5 @@
+import Logging
+
+public protocol Loggable {
+ var logger: Logger { get }
+}
diff --git a/Sources/WS/Protocols/Sendable.swift b/Sources/WS/Protocols/Sendable.swift
new file mode 100644
index 0000000..004d051
--- /dev/null
+++ b/Sources/WS/Protocols/Sendable.swift
@@ -0,0 +1,153 @@
+import Vapor
+import NIOWebSocket
+
+public protocol Sendable {
+ @discardableResult
+ func send(text: S) -> EventLoopFuture where S: Collection, S.Element == Character
+ @discardableResult
+ func send(bytes: [UInt8]) -> EventLoopFuture
+ @discardableResult
+ func send(data: Data) -> EventLoopFuture where Data: DataProtocol
+ @discardableResult
+ func send(data: Data, opcode: WebSocketOpcode) -> EventLoopFuture where Data: DataProtocol
+ @discardableResult
+ func send(model: C) -> EventLoopFuture where C: Encodable
+ @discardableResult
+ func send(model: C, encoder: Encoder) -> EventLoopFuture where C: Encodable
+ @discardableResult
+ func send(event: EID) -> EventLoopFuture
+ @discardableResult
+ func send(event: EID, payload: T?) -> EventLoopFuture
+}
+
+internal protocol _Sendable: Sendable {
+ var eventLoop: EventLoop { get }
+ var exchangeMode: ExchangeMode { get }
+ var logger: Logger { get }
+ var _encoder: Encoder { get }
+ var sockets: [WebSocket] { get }
+}
+
+extension _Sendable {
+ func _send(text: S) -> EventLoopFuture where S : Collection, S.Element == Character {
+ /// Send as `binary` instead
+ if exchangeMode == .binary {
+ self.logger.warning("[⚡️] ❗️📤❗️text will be automatically converted to binary data. Observer is in `binary` mode.")
+ return send(bytes: String(text).utf8.map{ UInt8($0) })
+ }
+ /// Send as `text`
+ return eventLoop.future().map {
+ self.sockets.forEach {
+ self.logger.debug("[⚡️] 📤 text: \(text)")
+ $0.send(text)
+ }
+ }
+ }
+
+ func _send(bytes: [UInt8]) -> EventLoopFuture {
+ /// Send as `text` instead
+ if exchangeMode == .text {
+ self.logger.warning("[⚡️] ❗️📤❗️bytes will be automatically converted to text. Observer is in `binary` mode.")
+ guard let text = String(bytes: bytes, encoding: .utf8) else {
+ self.logger.warning("[⚡️] ❗️📤❗️Unable to convert bytes to text. Observer is in `binary` mode.")
+ return eventLoop.future()
+ }
+ return send(text: text)
+ }
+ /// Send as `binary`
+ return eventLoop.future().map {
+ self.sockets.forEach {
+ self.logger.debug("[⚡️] 📤 bytes: \(bytes.count)")
+ $0.send(bytes)
+ }
+ }
+ }
+
+ func _send(data: Data) -> EventLoopFuture where Data : DataProtocol {
+ send(data: data, opcode: .binary)
+ }
+
+ func _send(data: Data, opcode: WebSocketOpcode) -> EventLoopFuture where Data: DataProtocol {
+ /// Send as `text` instead
+ if exchangeMode == .text {
+ self.logger.warning("[⚡️] ❗️📤❗️data will be automatically converted to text. Observer is in `text` mode.")
+ guard let text = String(bytes: data, encoding: .utf8) else {
+ self.logger.warning("[⚡️] ❗️📤❗️Unable to convert data to text. Observer is in `text` mode.")
+ return eventLoop.future()
+ }
+ return send(text: text)
+ }
+ /// Send as `binary`
+ return eventLoop.future().map {
+ self.sockets.forEach {
+ self.logger.debug("[⚡️] 📤 data: \(data.count)")
+ $0.send(raw: data, opcode: opcode)
+ }
+ }
+ }
+
+ func _send(model: C) -> EventLoopFuture where C: Encodable {
+ send(model: model, encoder: _encoder)
+ }
+
+ func _send(model: C, encoder: Encoder) -> EventLoopFuture where C: Encodable {
+ eventLoop.future().flatMapThrowing {
+ try encoder.encode(model)
+ }.flatMap { data -> EventLoopFuture in
+ if self.exchangeMode == .text {
+ return self.eventLoop.future(data)
+ }
+ return self.send(data: data).transform(to: data)
+ }.flatMap {
+ guard self.exchangeMode != .binary,
+ let text = String(data: $0, encoding: .utf8) else {
+ return self.eventLoop.future()
+ }
+ return self.send(text: text)
+ }
+ }
+
+ func _send(event: EID) -> EventLoopFuture {
+ send(event: event, payload: nil)
+ }
+
+ func _send(event: EID, payload: T?) -> EventLoopFuture {
+ send(model: Event(event: event.id, payload: payload))
+ }
+}
+
+// MARK: - EventLoopFuture
+
+extension EventLoopFuture: Sendable where Value: Sendable {
+ public func send(text: S) -> EventLoopFuture where S : Collection, S.Element == Character {
+ flatMap { $0.send(text: text) }
+ }
+
+ public func send(bytes: [UInt8]) -> EventLoopFuture {
+ flatMap { $0.send(bytes: bytes) }
+ }
+
+ public func send(data: Data) -> EventLoopFuture where Data : DataProtocol {
+ flatMap { $0.send(data: data) }
+ }
+
+ public func send(data: Data, opcode: WebSocketOpcode) -> EventLoopFuture where Data : DataProtocol {
+ flatMap { $0.send(data: data, opcode: opcode) }
+ }
+
+ public func send(model: C) -> EventLoopFuture where C : Encodable {
+ flatMap { $0.send(model: model) }
+ }
+
+ public func send(model: C, encoder: Encoder) -> EventLoopFuture where C : Encodable {
+ flatMap { $0.send(model: model, encoder: encoder) }
+ }
+
+ public func send(event: EID) -> EventLoopFuture where T : Decodable, T : Encodable {
+ flatMap { $0.send(event: event) }
+ }
+
+ public func send(event: EID, payload: T?) -> EventLoopFuture where T : Decodable, T : Encodable {
+ flatMap { $0.send(event: event, payload: payload) }
+ }
+}
diff --git a/Sources/WS/Protocols/Subscribable.swift b/Sources/WS/Protocols/Subscribable.swift
new file mode 100644
index 0000000..14de466
--- /dev/null
+++ b/Sources/WS/Protocols/Subscribable.swift
@@ -0,0 +1,68 @@
+import Vapor
+import NIOWebSocket
+
+public protocol Subscribable {
+ @discardableResult
+ func subscribe(to channels: String...) -> EventLoopFuture
+ @discardableResult
+ func subscribe(to channels: [String]) -> EventLoopFuture
+ @discardableResult
+ func unsubscribe(from channels: String...) -> EventLoopFuture
+ @discardableResult
+ func unsubscribe(from channels: [String]) -> EventLoopFuture
+}
+
+extension Subscribable {
+ public func subscribe(to channels: String...) -> EventLoopFuture {
+ subscribe(to: channels)
+ }
+
+ public func unsubscribe(from channels: String...) -> EventLoopFuture {
+ unsubscribe(from: channels)
+ }
+}
+
+internal protocol _Subscribable: class, Subscribable {
+ var eventLoop: EventLoop { get }
+ var clients: [_AnyClient] { get }
+}
+
+extension _Subscribable {
+ public func subscribe(to channels: String...) -> EventLoopFuture {
+ subscribe(to: channels)
+ }
+
+ public func subscribe(to channels: [String]) -> EventLoopFuture {
+ channels.forEach { channel in
+ self.clients.forEach {
+ $0.channels.insert(channel)
+ }
+ }
+ return eventLoop.future()
+ }
+
+ public func unsubscribe(from channels: String...) -> EventLoopFuture {
+ unsubscribe(from: channels)
+ }
+
+ public func unsubscribe(from channels: [String]) -> EventLoopFuture {
+ channels.forEach { channel in
+ self.clients.forEach {
+ $0.channels.remove(channel)
+ }
+ }
+ return eventLoop.future()
+ }
+}
+
+// MARK: - EventLoopFuture
+
+extension EventLoopFuture: Subscribable where Value: Subscribable {
+ public func subscribe(to channels: [String]) -> EventLoopFuture {
+ flatMap { $0.subscribe(to: channels) }
+ }
+
+ public func unsubscribe(from channels: [String]) -> EventLoopFuture {
+ flatMap { $0.unsubscribe(from: channels) }
+ }
+}
diff --git a/Sources/WS/Protocols/WSBroadcastable.swift b/Sources/WS/Protocols/WSBroadcastable.swift
deleted file mode 100644
index 10714fa..0000000
--- a/Sources/WS/Protocols/WSBroadcastable.swift
+++ /dev/null
@@ -1,21 +0,0 @@
-import Foundation
-import Vapor
-
-public protocol WSBroadcastable: class {
- func broadcast(_ text: String, to clients: Set, on container: Container) throws -> Future
- func broadcast(_ binary: Data, to clients: Set