Skip to content

Commit

Permalink
Use swift-atomics instead of NIOAtomics (#603)
Browse files Browse the repository at this point in the history
`NIOAtomics` was deprecated in apple/swift-nio#2204 in favor of `swift-atomics` https://github.com/apple/swift-atomics
  • Loading branch information
dnadoba authored Jul 13, 2022
1 parent 1af18d2 commit 2adca4b
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 32 deletions.
3 changes: 3 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ let package = Package(
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.10.0"),
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"),
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"),
],
targets: [
.target(name: "CAsyncHTTPClient"),
Expand All @@ -46,6 +47,7 @@ let package = Package(
.product(name: "NIOSOCKS", package: "swift-nio-extras"),
.product(name: "NIOTransportServices", package: "swift-nio-transport-services"),
.product(name: "Logging", package: "swift-log"),
.product(name: "Atomics", package: "swift-atomics"),
]
),
.testTarget(
Expand All @@ -61,6 +63,7 @@ let package = Package(
.product(name: "NIOHTTP2", package: "swift-nio-http2"),
.product(name: "NIOSOCKS", package: "swift-nio-extras"),
.product(name: "Logging", package: "swift-log"),
.product(name: "Atomics", package: "swift-atomics"),
],
resources: [
.copy("Resources/self_signed_cert.pem"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

import Atomics
import Logging
import NIOConcurrencyHelpers
import NIOCore
Expand Down Expand Up @@ -165,14 +166,14 @@ extension HTTPConnectionPool.Connection.ID {
static var globalGenerator = Generator()

struct Generator {
private let atomic: NIOAtomic<Int>
private let atomic: ManagedAtomic<Int>

init() {
self.atomic = .makeAtomic(value: 0)
self.atomic = .init(0)
}

func next() -> Int {
return self.atomic.add(1)
return self.atomic.loadThenWrappingIncrement(ordering: .relaxed)
}
}
}
25 changes: 13 additions & 12 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

import AsyncHTTPClient
import Atomics
import Foundation
import Logging
import NIOConcurrencyHelpers
Expand Down Expand Up @@ -351,7 +352,7 @@ internal final class HTTPBin<RequestHandler: ChannelInboundHandler> where
private let mode: Mode
private let sslContext: NIOSSLContext?
private var serverChannel: Channel!
private let isShutdown: NIOAtomic<Bool> = .makeAtomic(value: false)
private let isShutdown = ManagedAtomic(false)
private let handlerFactory: (Int) -> (RequestHandler)

init(
Expand All @@ -376,15 +377,15 @@ internal final class HTTPBin<RequestHandler: ChannelInboundHandler> where

self.activeConnCounterHandler = ConnectionsCountHandler()

let connectionIDAtomic = NIOAtomic<Int>.makeAtomic(value: 0)
let connectionIDAtomic = ManagedAtomic(0)

self.serverChannel = try! ServerBootstrap(group: self.group)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.serverChannelInitializer { channel in
channel.pipeline.addHandler(self.activeConnCounterHandler)
}.childChannelInitializer { channel in
do {
let connectionID = connectionIDAtomic.add(1)
let connectionID = connectionIDAtomic.loadThenWrappingIncrement(ordering: .relaxed)

if case .refuse = mode {
throw HTTPBinError.refusedConnection
Expand Down Expand Up @@ -572,12 +573,12 @@ internal final class HTTPBin<RequestHandler: ChannelInboundHandler> where
}

func shutdown() throws {
self.isShutdown.store(true)
self.isShutdown.store(true, ordering: .relaxed)
try self.group.syncShutdownGracefully()
}

deinit {
assert(self.isShutdown.load(), "HTTPBin not shutdown before deinit")
assert(self.isShutdown.load(ordering: .relaxed), "HTTPBin not shutdown before deinit")
}
}

Expand Down Expand Up @@ -946,24 +947,24 @@ internal final class HTTPBinHandler: ChannelInboundHandler {
final class ConnectionsCountHandler: ChannelInboundHandler {
typealias InboundIn = Channel

private let activeConns = NIOAtomic<Int>.makeAtomic(value: 0)
private let createdConns = NIOAtomic<Int>.makeAtomic(value: 0)
private let activeConns = ManagedAtomic(0)
private let createdConns = ManagedAtomic(0)

var createdConnections: Int {
self.createdConns.load()
self.createdConns.load(ordering: .relaxed)
}

var currentlyActiveConnections: Int {
self.activeConns.load()
self.activeConns.load(ordering: .relaxed)
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let channel = self.unwrapInboundIn(data)

_ = self.activeConns.add(1)
_ = self.createdConns.add(1)
_ = self.activeConns.loadThenWrappingIncrement(ordering: .relaxed)
_ = self.createdConns.loadThenWrappingIncrement(ordering: .relaxed)
channel.closeFuture.whenComplete { _ in
_ = self.activeConns.sub(1)
_ = self.activeConns.loadThenWrappingDecrement(ordering: .relaxed)
}

context.fireChannelRead(data)
Expand Down
31 changes: 16 additions & 15 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

/* NOT @testable */ import AsyncHTTPClient // Tests that need @testable go into HTTPClientInternalTests.swift
import Atomics
#if canImport(Network)
import Network
#endif
Expand Down Expand Up @@ -1790,16 +1791,16 @@ class HTTPClientTests: XCTestCase {
typealias InboundIn = HTTPServerRequestPart
typealias OutboundOut = HTTPServerResponsePart

let requestNumber: NIOAtomic<Int>
let connectionNumber: NIOAtomic<Int>
let requestNumber: ManagedAtomic<Int>
let connectionNumber: ManagedAtomic<Int>

init(requestNumber: NIOAtomic<Int>, connectionNumber: NIOAtomic<Int>) {
init(requestNumber: ManagedAtomic<Int>, connectionNumber: ManagedAtomic<Int>) {
self.requestNumber = requestNumber
self.connectionNumber = connectionNumber
}

func channelActive(context: ChannelHandlerContext) {
_ = self.connectionNumber.add(1)
_ = self.connectionNumber.loadThenWrappingIncrement(ordering: .relaxed)
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
Expand All @@ -1809,7 +1810,7 @@ class HTTPClientTests: XCTestCase {
case .head, .body:
()
case .end:
let last = self.requestNumber.add(1)
let last = self.requestNumber.loadThenWrappingIncrement(ordering: .relaxed)
switch last {
case 0, 2:
context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))),
Expand All @@ -1824,8 +1825,8 @@ class HTTPClientTests: XCTestCase {
}
}

let requestNumber = NIOAtomic<Int>.makeAtomic(value: 0)
let connectionNumber = NIOAtomic<Int>.makeAtomic(value: 0)
let requestNumber = ManagedAtomic(0)
let connectionNumber = ManagedAtomic(0)
let sharedStateServerHandler = ServerThatAcceptsThenRejects(requestNumber: requestNumber,
connectionNumber: connectionNumber)
var maybeServer: Channel?
Expand Down Expand Up @@ -1854,19 +1855,19 @@ class HTTPClientTests: XCTestCase {
XCTAssertNoThrow(try client.syncShutdown())
}

XCTAssertEqual(0, sharedStateServerHandler.connectionNumber.load())
XCTAssertEqual(0, sharedStateServerHandler.requestNumber.load())
XCTAssertEqual(0, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed))
XCTAssertEqual(0, sharedStateServerHandler.requestNumber.load(ordering: .relaxed))
XCTAssertEqual(.ok, try client.get(url: url).wait().status)
XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load())
XCTAssertEqual(1, sharedStateServerHandler.requestNumber.load())
XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed))
XCTAssertEqual(1, sharedStateServerHandler.requestNumber.load(ordering: .relaxed))
XCTAssertThrowsError(try client.get(url: url).wait().status) { error in
XCTAssertEqual(.remoteConnectionClosed, error as? HTTPClientError)
}
XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load())
XCTAssertEqual(2, sharedStateServerHandler.requestNumber.load())
XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed))
XCTAssertEqual(2, sharedStateServerHandler.requestNumber.load(ordering: .relaxed))
XCTAssertEqual(.ok, try client.get(url: url).wait().status)
XCTAssertEqual(2, sharedStateServerHandler.connectionNumber.load())
XCTAssertEqual(3, sharedStateServerHandler.requestNumber.load())
XCTAssertEqual(2, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed))
XCTAssertEqual(3, sharedStateServerHandler.requestNumber.load(ordering: .relaxed))
}

func testPoolClosesIdleConnections() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

@testable import AsyncHTTPClient
import Atomics
import Dispatch
import NIOConcurrencyHelpers
import NIOCore
Expand All @@ -21,14 +22,14 @@ import NIOEmbedded
/// An `EventLoopGroup` of `EmbeddedEventLoop`s.
final class EmbeddedEventLoopGroup: EventLoopGroup {
private let loops: [EmbeddedEventLoop]
private let index = NIOAtomic<Int>.makeAtomic(value: 0)
private let index = ManagedAtomic(0)

internal init(loops: Int) {
self.loops = (0..<loops).map { _ in EmbeddedEventLoop() }
}

internal func next() -> EventLoop {
let index: Int = self.index.add(1)
let index: Int = self.index.loadThenWrappingIncrement(ordering: .relaxed)
return self.loops[index % self.loops.count]
}

Expand Down

0 comments on commit 2adca4b

Please sign in to comment.