diff --git a/src/Connection.php b/src/Connection.php index 9d0dc3f..5f162ec 100644 --- a/src/Connection.php +++ b/src/Connection.php @@ -260,4 +260,13 @@ public function dropTable(string $name): PromiseInterface; * @throws TableNotFoundException */ public function truncateTable(string $name): PromiseInterface; + + /** + * @return \React\Promise\PromiseInterface<\Drift\DBAL\SingleConnection> + */ + public function startTransaction(): PromiseInterface; + + public function commitTransaction(SingleConnection $connection): PromiseInterface; + + public function rollbackTransaction(SingleConnection $connection): PromiseInterface; } diff --git a/src/ConnectionPool.php b/src/ConnectionPool.php index 149cc5b..f51f614 100644 --- a/src/ConnectionPool.php +++ b/src/ConnectionPool.php @@ -23,31 +23,51 @@ use Doctrine\DBAL\Query\QueryBuilder; use Doctrine\DBAL\Schema\Schema; use Drift\DBAL\Driver\Driver; +use React\Promise\Deferred; use React\Promise\PromiseInterface; +use RuntimeException; +use SplObjectStorage; + +use function React\Promise\resolve; /** * Class ConnectionPool. */ class ConnectionPool implements Connection, ConnectionPoolInterface { - private array $connections; + + /** + * @var SplObjectStorage|array<\Drift\DBAL\Connection, \Drift\DBAL\ConnectionWorker> $connections + */ + private SplObjectStorage $connections; + + /** + * @var SplObjectStorage|Deferred[] $deferreds + */ + private SplObjectStorage $deferreds; /** * Connection constructor. * - * @param ConnectionWorker[] $connections + * @param ConnectionWorker[] $workers */ - private function __construct(array $connections) + private function __construct(array $workers) { - $this->connections = $connections; + $this->connections = new SplObjectStorage; + + foreach ($workers as $worker) { + $this->connections->attach($worker->getConnection(), $worker); + } + + $this->deferreds = new SplObjectStorage; } /** * Create new connection. * - * @param Driver $driver - * @param Credentials $credentials - * @param AbstractPlatform $platform + * @param Driver $driver + * @param Credentials $credentials + * @param AbstractPlatform $platform * @param ConnectionOptions|null $options * * @return Connection @@ -64,6 +84,9 @@ public static function create( $numberOfConnections = $options->getNumberOfConnections(); } + // Since using transactions with a single connection in an asynchronous environment + // probably doesn't do what you want it to do* we explicitly disallow it. + if ($numberOfConnections <= 1) { return SingleConnection::create( $driver, @@ -73,27 +96,28 @@ public static function create( ); } - $connections = []; + $workers = []; for ($i = 0; $i < $numberOfConnections; ++$i) { - $connections[] = new ConnectionWorker( + $workers[] = new ConnectionWorker( SingleConnection::create( clone $driver, $credentials, $platform, - $options + $options, + true ), $i ); } - return new self($connections); + return new self($workers); } /** * Create new connection. * - * @param Driver $driver - * @param Credentials $credentials - * @param AbstractPlatform $platform + * @param Driver $driver + * @param Credentials $credentials + * @param AbstractPlatform $platform * @param ConnectionOptions|null $options * * @return Connection @@ -115,41 +139,49 @@ public static function createConnected( */ public function getDriverNamespace(): string { - return $this - ->bestConnection() - ->getConnection() - ->getDriverNamespace(); + $connections = clone $this->connections; + $worker = $connections->current(); + return $worker->getDriverNamespace(); } /** * @param bool $increaseJobs * - * @return ConnectionWorker + * @return PromiseInterface */ - private function bestConnection(bool $increaseJobs = false): ConnectionWorker - { + private function bestConnectionWorker( + bool $increaseJobs = false + ): PromiseInterface { $minJobs = 1000000000; - $minJobsConnection = null; - foreach ($this->connections as $i => $connection) { - if ($connection->getJobs() < $minJobs) { - $minJobs = $connection->getJobs(); - $minJobsConnection = $i; + $bestConnection = null; + + foreach ($this->connections as $connection) { + $worker = $this->connections->getInfo(); + if ($worker->getJobs() < $minJobs && !$worker->isLeased()) { + $minJobs = $worker->getJobs(); + $bestConnection = $worker; } } - if ($increaseJobs) { - $this->connections[$minJobsConnection]->startJob(); + if ($bestConnection !== null) { + if ($increaseJobs) { + $bestConnection->startJob(); + } + + return resolve($bestConnection); } - return $this->connections[$minJobsConnection]; - } + // We may want to introduce some logging here. + // fwrite(STDOUT, 'no more workers' . PHP_EOL); - /** - * @return ConnectionWorker - */ - private function firstConnection(): ConnectionWorker - { - return $this->connections[0]; + $deferred = new Deferred; + $this->deferreds->attach($deferred); + + if ($increaseJobs) { + $deferred->promise()->then(fn(ConnectionWorker $w) => $w->startJob()); + } + + return $deferred->promise(); } /** @@ -159,12 +191,15 @@ private function firstConnection(): ConnectionWorker */ private function executeInBestConnection(callable $callable): PromiseInterface { - $connectionWorker = $this->bestConnection(true); - - return $callable($connectionWorker->getConnection()) - ->then(function ($whatever) use ($connectionWorker) { - $connectionWorker->stopJob(); - + $worker = null; + + return $this->bestConnectionWorker(true) + ->then(function (ConnectionWorker $connectionWorker) use ($callable, &$worker) { + $worker = $connectionWorker; + return $callable($worker->getConnection()); + }) + ->then(function ($whatever) use (&$worker) { + $worker->stopJob(); return $whatever; }); } @@ -175,7 +210,7 @@ private function executeInBestConnection(callable $callable): PromiseInterface public function connect(?ConnectionOptions $options = null) { foreach ($this->connections as $connection) { - $connection->getConnection()->connect($options); + $connection->connect($options); } } @@ -185,7 +220,7 @@ public function connect(?ConnectionOptions $options = null) public function close() { foreach ($this->connections as $connection) { - $connection->getConnection()->close(); + $connection->close(); } } @@ -198,9 +233,13 @@ public function close() */ public function createQueryBuilder(): QueryBuilder { - return $this - ->firstConnection() - ->getConnection() + // We clone the worker storage because we probably can't rely on + // SplObjectStorage::current(...) working reliably in an + // asynchronous environment. + $workers = clone $this->connections; + $workers->rewind(); + + return $workers->current() ->createQueryBuilder(); } @@ -222,7 +261,7 @@ public function query(QueryBuilder $queryBuilder): PromiseInterface * Query by sql and parameters. * * @param string $sql - * @param array $parameters + * @param array $parameters * * @return PromiseInterface */ @@ -271,7 +310,7 @@ public function executeSchema(Schema $schema): PromiseInterface * connection->findOneById('table', ['id' => 1]); * * @param string $table - * @param array $where + * @param array $where * * @return PromiseInterface */ @@ -290,7 +329,7 @@ public function findOneBy( * connection->findBy('table', ['id' => 1]); * * @param string $table - * @param array $where + * @param array $where * * @return PromiseInterface */ @@ -305,7 +344,7 @@ public function findBy( /** * @param string $table - * @param array $values + * @param array $values * * @return PromiseInterface */ @@ -320,7 +359,7 @@ public function insert( /** * @param string $table - * @param array $values + * @param array $values * * @return PromiseInterface * @@ -337,8 +376,8 @@ public function delete( /** * @param string $table - * @param array $id - * @param array $values + * @param array $id + * @param array $values * * @return PromiseInterface * @@ -356,8 +395,8 @@ public function update( /** * @param string $table - * @param array $id - * @param array $values + * @param array $id + * @param array $values * * @return PromiseInterface * @@ -385,9 +424,9 @@ public function upsert( * First field is considered as primary key. * * @param string $name - * @param array $fields - * @param array $extra - * @param bool $autoincrementId + * @param array $fields + * @param array $extra + * @param bool $autoincrementId * * @return PromiseInterface * @@ -400,7 +439,12 @@ public function createTable( array $extra = [], bool $autoincrementId = false ): PromiseInterface { - return $this->executeInBestConnection(function (Connection $connection) use ($name, $fields, $extra, $autoincrementId) { + return $this->executeInBestConnection(function (Connection $connection) use ( + $name, + $fields, + $extra, + $autoincrementId + ) { return $connection->createTable($name, $fields, $extra, $autoincrementId); }); } @@ -438,8 +482,62 @@ public function truncateTable(string $name): PromiseInterface * * @return ConnectionWorker[] */ - public function getWorkers(): array + public function getConnections(): array + { + $workers = []; + + $connections = clone $this->connections; + foreach ($connections as $connection) { + $workers[] = $connections->getInfo(); + } + + return $workers; + } + + public function startTransaction(): PromiseInterface + { + return $this->bestConnectionWorker() + ->then(function (ConnectionWorker $worker) { + $worker->setLeased(true); + + $connection = $worker->getConnection(); + + if (!$connection instanceof SingleConnection) { + throw new RuntimeException('connection must be instance of ' . SingleConnection::class); + } + + $connection->startTransaction(); + + return resolve($connection); + }); + } + + public function commitTransaction(SingleConnection $connection): PromiseInterface { - return $this->connections; + return $connection->commitTransaction($connection) + ->always(fn() => $this->releaseConnection($connection)); } + + public function rollbackTransaction(SingleConnection $connection): PromiseInterface + { + return $connection->rollbackTransaction($connection) + ->always(fn() => $this->releaseConnection($connection)); + } + + private function releaseConnection(SingleConnection $connection): PromiseInterface + { + if (count($this->deferreds) === 0) { + /** @var \Drift\DBAL\ConnectionWorker $worker */ + $worker = $this->connections[$connection]; + $worker->setLeased(false); + return resolve(); + } + + $deferred = $this->deferreds->current(); + $this->deferreds->detach($deferred); + + $deferred->resolve($this->connections[$connection]); + return resolve(); + } + } diff --git a/src/ConnectionPoolInterface.php b/src/ConnectionPoolInterface.php index d90eb31..b660542 100644 --- a/src/ConnectionPoolInterface.php +++ b/src/ConnectionPoolInterface.php @@ -9,5 +9,5 @@ interface ConnectionPoolInterface * * @return ConnectionWorker[] */ - public function getWorkers(): array; + public function getConnections(): array; } diff --git a/src/ConnectionWorker.php b/src/ConnectionWorker.php index 3537d7b..4014d8b 100644 --- a/src/ConnectionWorker.php +++ b/src/ConnectionWorker.php @@ -23,6 +23,7 @@ class ConnectionWorker private Connection $connection; private int $id; private int $jobs; + private bool $leased = false; /** * @param Connection $connection @@ -68,4 +69,16 @@ public function getJobs(): int { return $this->jobs; } + + public function isLeased(): bool + { + return $this->leased; + } + + public function setLeased(bool $leased): ConnectionWorker + { + $this->leased = $leased; + return $this; + } + } diff --git a/src/Driver/AbstractDriver.php b/src/Driver/AbstractDriver.php index a85d79e..c50c053 100644 --- a/src/Driver/AbstractDriver.php +++ b/src/Driver/AbstractDriver.php @@ -53,4 +53,6 @@ protected function createInsertQuery(QueryBuilder $queryBuilder, string $table, )) ->setParameters(array_values($values)); } + + } diff --git a/src/Driver/Driver.php b/src/Driver/Driver.php index 9a30568..b55f6ce 100644 --- a/src/Driver/Driver.php +++ b/src/Driver/Driver.php @@ -62,4 +62,10 @@ public function insert( * @return void */ public function close(): void; + + public function startTransaction(): PromiseInterface; + + public function commitTransaction(): PromiseInterface; + + public function rollbackTransaction(): PromiseInterface; } diff --git a/src/Driver/Mysql/MysqlDriver.php b/src/Driver/Mysql/MysqlDriver.php index 1db6495..c85733e 100644 --- a/src/Driver/Mysql/MysqlDriver.php +++ b/src/Driver/Mysql/MysqlDriver.php @@ -94,4 +94,20 @@ public function close(): void { $this->connection->close(); } + + public function startTransaction(): PromiseInterface + { + return $this->query('START TRANSACTION', []); + } + + public function commitTransaction(): PromiseInterface + { + return $this->query('COMMIT', []); + } + + public function rollbackTransaction(): PromiseInterface + { + return $this->query('ROLLBACK', []); + } + } diff --git a/src/Driver/PostgreSQL/PostgreSQLDriver.php b/src/Driver/PostgreSQL/PostgreSQLDriver.php index 2ddf5bd..87e67ad 100644 --- a/src/Driver/PostgreSQL/PostgreSQLDriver.php +++ b/src/Driver/PostgreSQL/PostgreSQLDriver.php @@ -176,4 +176,19 @@ public function close(): void ->connection ->disconnect(); } + + public function startTransaction(): PromiseInterface + { + return $this->query('START TRANSACTION', []); + } + + public function commitTransaction(): PromiseInterface + { + return $this->query('COMMIT', []); + } + + public function rollbackTransaction(): PromiseInterface + { + return $this->query('ROLLBACK', []); + } } diff --git a/src/Driver/SQLite/SQLiteDriver.php b/src/Driver/SQLite/SQLiteDriver.php index c313e81..bf58d6f 100644 --- a/src/Driver/SQLite/SQLiteDriver.php +++ b/src/Driver/SQLite/SQLiteDriver.php @@ -90,4 +90,19 @@ public function close(): void { $this->database->close(); } + + public function startTransaction(): PromiseInterface + { + return $this->query('BEGIN TRANSACTION', []); + } + + public function commitTransaction(): PromiseInterface + { + return $this->query('COMMIT', []); + } + + public function rollbackTransaction(): PromiseInterface + { + return $this->query('ROLLBACK', []); + } } diff --git a/src/SingleConnection.php b/src/SingleConnection.php index 82b4205..e16d37e 100644 --- a/src/SingleConnection.php +++ b/src/SingleConnection.php @@ -27,8 +27,10 @@ use Drift\DBAL\Mock\MockedDriver; use React\EventLoop\Loop; use React\EventLoop\TimerInterface; +use RuntimeException; use function React\Promise\map; use React\Promise\PromiseInterface; +use function React\Promise\resolve; class SingleConnection implements Connection { @@ -36,6 +38,7 @@ class SingleConnection implements Connection private Credentials $credentials; private AbstractPlatform $platform; private ?TimerInterface $keepAliveTimer = null; + private bool $allowTransactions; /** * Connection constructor. @@ -47,11 +50,13 @@ class SingleConnection implements Connection private function __construct( Driver $driver, Credentials $credentials, - AbstractPlatform $platform + AbstractPlatform $platform, + bool $allowTransactions = false ) { $this->driver = $driver; $this->credentials = $credentials; $this->platform = $platform; + $this->allowTransactions = $allowTransactions; } public function __destruct() @@ -75,9 +80,10 @@ public static function create( Driver $driver, Credentials $credentials, AbstractPlatform $platform, - ?ConnectionOptions $options = null + ?ConnectionOptions $options = null, + bool $allowTransactions = false ): Connection { - return new self($driver, $credentials, $platform); + return new self($driver, $credentials, $platform, $allowTransactions); } /** @@ -524,4 +530,34 @@ private function applyWhereClausesFromArray( $queryBuilder->setParameters($params); } + + public function startTransaction(): PromiseInterface + { + if (!$this->allowTransactions) { + throw new RuntimeException('starting a transaction in a SingleConnection is not allowed'); + } + + return $this->driver->startTransaction() + ->then(fn(...$args) => $this); + } + + public function commitTransaction(SingleConnection $connection): PromiseInterface + { + if (!$this->allowTransactions) { + throw new RuntimeException('starting a transaction in a SingleConnection is not allowed'); + } + + return $this->driver->commitTransaction() + ->then(fn(...$args) => $this); + } + + public function rollbackTransaction(SingleConnection $connection): PromiseInterface + { + if (!$this->allowTransactions) { + throw new RuntimeException('starting a transaction in a SingleConnection is not allowed'); + } + + return $this->driver->rollbackTransaction() + ->then(fn(...$args) => $this); + } } diff --git a/tests/ConnectionTest.php b/tests/ConnectionTest.php index 7c1c82c..d7b93da 100644 --- a/tests/ConnectionTest.php +++ b/tests/ConnectionTest.php @@ -15,7 +15,6 @@ namespace Drift\DBAL\Tests; -use function Clue\React\Block\await; use Doctrine\DBAL\Exception as DBALException; use Doctrine\DBAL\Exception\InvalidArgumentException; use Doctrine\DBAL\Exception\TableExistsException; @@ -23,17 +22,22 @@ use Doctrine\DBAL\Exception\UniqueConstraintViolationException; use Drift\DBAL\Connection; use Drift\DBAL\Result; +use Drift\DBAL\SingleConnection; use PHPUnit\Framework\TestCase; use React\EventLoop\Factory; use React\EventLoop\LoopInterface; -use function React\Promise\all; use React\Promise\PromiseInterface; +use RuntimeException; + +use function Clue\React\Block\await; +use function React\Promise\all; /** * Class ConnectionTest. */ abstract class ConnectionTest extends TestCase { + /** * The timeout is used to prevent tests from endless waiting. * Consider this amount of seconds as a reasonable timeout @@ -50,7 +54,7 @@ abstract protected function getConnection(LoopInterface $loop): Connection; /** * @param Connection $connection - * @param bool $autoincrementedId + * @param bool $autoincrementedId * * @return PromiseInterface */ @@ -94,7 +98,7 @@ protected function dropInfrastructure(Connection $connection): PromiseInterface /** * @param Connection $connection - * @param bool $autoincrementedId + * @param bool $autoincrementedId * * @return PromiseInterface */ @@ -198,10 +202,10 @@ public function testMultipleRows() ->then(function (Connection $connection) { return $connection ->insert('test', [ - 'id' => '1', - 'field1' => 'val11', - 'field2' => 'val12', - ]) + 'id' => '1', + 'field1' => 'val11', + 'field2' => 'val12', + ]) ->then(function () use ($connection) { return $connection->insert('test', [ 'id' => '2', @@ -222,13 +226,13 @@ public function testMultipleRows() return $connection ->query($queryBuilder - ->select('*') - ->from('test', 't') - ->where($queryBuilder->expr()->orX( - $queryBuilder->expr()->eq('t.id', '?'), - $queryBuilder->expr()->eq('t.id', '?') - )) - ->setParameters(['1', '2'])); + ->select('*') + ->from('test', 't') + ->where($queryBuilder->expr()->orX( + $queryBuilder->expr()->eq('t.id', '?'), + $queryBuilder->expr()->eq('t.id', '?') + )) + ->setParameters(['1', '2'])); }) ->then(function (Result $result) { $this->assertCount(2, $result->fetchAllRows()); @@ -288,45 +292,45 @@ public function testFindShortcut() 'field2' => 'val2', ]), ]) - ->then(function () use ($connection) { - return all([ - $connection->findOneBy('test', [ + ->then(function () use ($connection) { + return all([ + $connection->findOneBy('test', [ + 'id' => '1', + ]), + $connection->findOneBy('test', [ + 'id' => '999', + ]), + $connection->findBy('test', [ + 'field1' => 'val1', + ]), + ]); + }) + ->then(function (array $results) { + $this->assertEquals($results[0], [ 'id' => '1', - ]), - $connection->findOneBy('test', [ - 'id' => '999', - ]), - $connection->findBy('test', [ 'field1' => 'val1', - ]), - ]); - }) - ->then(function (array $results) { - $this->assertEquals($results[0], [ - 'id' => '1', - 'field1' => 'val1', - 'field2' => 'val1', - ]); + 'field2' => 'val1', + ]); - $this->assertNull($results[1]); - $listResults = $results[2]; - usort($listResults, function ($a1, $a2) { - return $a1['id'] > $a2['id']; + $this->assertNull($results[1]); + $listResults = $results[2]; + usort($listResults, function ($a1, $a2) { + return $a1['id'] > $a2['id']; + }); + + $this->assertSame($listResults, [ + [ + 'id' => '1', + 'field1' => 'val1', + 'field2' => 'val1', + ], + [ + 'id' => '2', + 'field1' => 'val1', + 'field2' => 'val2', + ], + ]); }); - - $this->assertSame($listResults, [ - [ - 'id' => '1', - 'field1' => 'val1', - 'field2' => 'val1', - ], - [ - 'id' => '2', - 'field1' => 'val1', - 'field2' => 'val2', - ], - ]); - }); }); await($promise, $loop, self::MAX_TIMEOUT); @@ -360,7 +364,7 @@ public function testFindByNullValue() ]); }); - list($result0, $result1) = await($promise, $loop, self::MAX_TIMEOUT); + [$result0, $result1] = await($promise, $loop, self::MAX_TIMEOUT); $this->assertEquals('1', $result0['id']); $this->assertCount(1, $result1); @@ -644,4 +648,103 @@ public function testCloseConnection() await($promise, $loop, self::MAX_TIMEOUT); } + + public function testTransactionRollback(): void + { + $loop = $this->createLoop(); + $connection = $this->getConnection($loop); + + if ($connection instanceof SingleConnection) { + $this->expectException(RuntimeException::class); + await($connection->startTransaction()); + return; + } + + await($this->resetInfrastructure($connection, true)); + + $transaction = null; + + $promise = $connection->startTransaction() + ->then(function (Connection $c) use (&$transaction) { + $transaction = $c; + + return $transaction->insert('test', [ + 'field1' => 'transaction', + 'field2' => 'rollback', + ]); + }) + ->then(function () use (&$connection) { + // This should happen on a different connection because + // the transaction one has been leased. + return $connection->insert('test', [ + 'field1' => 'other', + 'field2' => 'coroutine', + ]); + }) + ->then(function () use ($connection, &$transaction) { + /** @var SingleConnection $transaction */ + return $connection->rollbackTransaction($transaction); + }); + + await($promise); + + $queryBuilder = $connection->createQueryBuilder() + ->select('*') + ->from('test') + ->where("field1 = 'transaction' and field2 = 'rollback'"); + /** @var Result $result */ + $result = await($connection->query($queryBuilder)); + + self::assertEmpty($result->fetchAllRows()); + + $queryBuilder = $connection->createQueryBuilder() + ->select('*') + ->from('test') + ->where("field1 = 'other' and field2 = 'coroutine'"); + /** @var Result $result */ + $result = await($connection->query($queryBuilder)); + + self::assertEquals(1, $result->fetchCount()); + } + + public function testTransactionCommit(): void + { + $loop = $this->createLoop(); + $connection = $this->getConnection($loop); + + if ($connection instanceof SingleConnection) { + $this->expectException(RuntimeException::class); + await($connection->startTransaction()); + return; + } + + await($this->resetInfrastructure($connection, true)); + + $transaction = null; + + $promise = $connection->startTransaction() + ->then(function (Connection $c) use (&$transaction) { + $transaction = $c; + + return $transaction->insert('test', [ + 'field1' => 'transaction', + 'field2' => 'commit', + ]); + }) + ->then(function () use ($connection, &$transaction) { + /** @var SingleConnection $transaction */ + return $connection->commitTransaction($transaction); + }); + + await($promise); + + $queryBuilder = $connection->createQueryBuilder() + ->select('*') + ->from('test') + ->where("field1 = 'transaction' and field2 = 'commit'"); + /** @var Result $result */ + $result = await($connection->query($queryBuilder)); + + self::assertEquals(1, $result->fetchCount()); + } } diff --git a/tests/SQLiteConnectionTest.php b/tests/SQLiteConnectionTest.php index c83503a..e1b6937 100644 --- a/tests/SQLiteConnectionTest.php +++ b/tests/SQLiteConnectionTest.php @@ -32,7 +32,7 @@ class SQLiteConnectionTest extends ConnectionTest */ public function getConnection(LoopInterface $loop): Connection { - $mysqlPlatform = new SqlitePlatform(); + $platform = new SqlitePlatform(); return SingleConnection::createConnected(new SQLiteDriver( $loop @@ -42,6 +42,6 @@ public function getConnection(LoopInterface $loop): Connection 'root', 'root', ':memory:' - ), $mysqlPlatform); + ), $platform); } }