From 06151f9d50a081788f991691be3c0e6959ffffc7 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Mon, 17 Oct 2016 20:16:39 +0200 Subject: [PATCH 1/7] Proper (concurrent) transactions --- twistar/dbconfig/base.py | 11 +- twistar/tests/test_transactions.py | 246 +++++++++++++++++++++++++---- twistar/transaction.py | 161 +++++++++++++++++++ twistar/utils.py | 24 --- 4 files changed, 382 insertions(+), 60 deletions(-) create mode 100644 twistar/transaction.py diff --git a/twistar/dbconfig/base.py b/twistar/dbconfig/base.py index d2e45cb..1f66c9c 100644 --- a/twistar/dbconfig/base.py +++ b/twistar/dbconfig/base.py @@ -8,6 +8,7 @@ from twistar.registry import Registry from twistar.exceptions import ImaginaryTableError, CannotRefreshError from twistar.utils import joinWheres +from twistar.transaction import TransactionGuard class InteractionBase(object): @@ -25,7 +26,7 @@ class InteractionBase(object): def __init__(self): - self.txn = None + self.txnGuard = TransactionGuard() def logEncode(self, s, encoding='utf-8'): @@ -156,7 +157,6 @@ def _doselect(self, txn, q, args, tablename, one=False, cacheable=True): results.append(vals) return results - def insertArgsToString(self, vals): """ Convert C{{'name': value}} to an insert "values" string like C{"(%s,%s,%s)"}. @@ -327,8 +327,8 @@ def getSchema(self, tablename, txn=None): def runInteraction(self, interaction, *args, **kwargs): - if self.txn is not None: - return defer.succeed(interaction(self.txn, *args, **kwargs)) + if self.txnGuard.txn is not None: + return defer.succeed(interaction(self.txnGuard.txn, *args, **kwargs)) return Registry.DBPOOL.runInteraction(interaction, *args, **kwargs) @@ -345,8 +345,7 @@ def _doinsert(txn): if len(cols) == 0: raise ImaginaryTableError("Table %s does not exist." % tablename) vals = obj.toHash(cols, includeBlank=self.__class__.includeBlankInInsert, exclude=['id']) - self.insert(tablename, vals, txn) - obj.id = self.getLastInsertID(txn) + obj.id = self.insert(tablename, vals, txn) return obj return self.runInteraction(_doinsert) diff --git a/twistar/tests/test_transactions.py b/twistar/tests/test_transactions.py index 8629f99..153c804 100644 --- a/twistar/tests/test_transactions.py +++ b/twistar/tests/test_transactions.py @@ -1,67 +1,253 @@ +import sys +from threading import Event + from twisted.trial import unittest -from twisted.internet.defer import inlineCallbacks +from twisted.internet import reactor +from twisted.internet.error import AlreadyCalled +from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.python import threadable -from twistar.utils import transaction +from twistar.transaction import transaction from twistar.exceptions import TransactionError -from utils import initDB, tearDownDB, Registry, Transaction +from twistar.tests.utils import initDB, tearDownDB, Registry, Transaction, DBTYPE + +class TransactionTests(unittest.TestCase): -class TransactionTest(unittest.TestCase): @inlineCallbacks def setUp(self): yield initDB(self) self.config = Registry.getConfig() - @inlineCallbacks def tearDown(self): - yield tearDownDB(self) + d_tearDown = tearDownDB(self) + delayed = reactor.callLater(2, d_tearDown.cancel) + + try: + yield d_tearDown + delayed.cancel() + except: + print "db cleanup timed out" + + @inlineCallbacks + def _assertRaises(self, deferred, *excTypes): + # required for downward compatibility + + excType = None + try: + yield deferred + except: + excType, exc, tb = sys.exc_info() + + msgFormat = "Deferred expected to fail with " + ", ".join(str(expType) for expType in excTypes) + "; instead got {}" + if not excType: + self.fail(msgFormat.format("Nothing")) + else: + self.failIf(not issubclass(excType, *excTypes), msgFormat.format(excType)) + + @transaction + def test_set_cfg_txn(txn, self): + """Verify that the transaction is actually being set correctly""" + self.assertIs(txn, Registry.getConfig().txnGuard.txn) + + with transaction() as txn2: + self.assertIs(txn2, Registry.getConfig().txnGuard.txn) + self.assertIs(txn, Registry.getConfig().txnGuard.txn) @inlineCallbacks - def test_findOrCreate(self): + def test_commit(self): + barrier = Event() + @transaction @inlineCallbacks - def interaction(txn): - yield Transaction.findOrCreate(name="a name") - yield Transaction.findOrCreate(name="a name") + def trans(txn): + self.assertFalse(threadable.isInIOThread(), "Transactions must not run in main thread") + + yield Transaction(name="TEST1").save() + yield Transaction(name="TEST2").save() + + barrier.wait() # wait here to delay commit + returnValue("return value") + + d = trans() - yield interaction() count = yield Transaction.count() - self.assertEqual(count, 1) + self.assertEqual(count, 0) + barrier.set() + res = yield d + self.assertEqual(res, "return value") + + count = yield Transaction.count() + self.assertEqual(count, 2) @inlineCallbacks - def test_doubleInsert(self): + def test_rollback(self): + barrier = Event() @transaction - def interaction(txn): - def finish(trans): - return Transaction(name="unique name").save() - return Transaction(name="unique name").save().addCallback(finish) + @inlineCallbacks + def trans(txn): + yield Transaction(name="TEST1").save() + yield Transaction(name="TEST2").save() - try: - yield interaction() - except TransactionError: - pass + barrier.wait() # wait here to delay commit + raise ZeroDivisionError() + + d = trans() + + barrier.set() + yield self._assertRaises(d, ZeroDivisionError) - # there should be no transaction records stored at all count = yield Transaction.count() self.assertEqual(count, 0) - @inlineCallbacks - def test_success(self): + def test_fake_nesting_commit(self): + barrier = Event() + threadIds = [] + + @transaction + @inlineCallbacks + def trans1(txn): + threadIds.append(threadable.getThreadID()) + yield Transaction(name="TEST1").save() @transaction - def interaction(txn): - def finish(trans): - return Transaction(name="unique name two").save() - return Transaction(name="unique name").save().addCallback(finish) + @inlineCallbacks + def trans2(txn): + threadIds.append(threadable.getThreadID()) + yield trans1() + yield Transaction(name="TEST2").save() + barrier.wait() # wait here to delay commit + + d = trans2() + + count = yield Transaction.count() + self.assertEqual(count, 0) - result = yield interaction() - self.assertEqual(result.id, 2) + barrier.set() + yield d + + self.assertEqual(threadIds[0], threadIds[1], "Nested transactions don't run in same thread") count = yield Transaction.count() self.assertEqual(count, 2) + + @inlineCallbacks + def test_fake_nesting_rollback(self): + barrier = Event() + + @transaction + @inlineCallbacks + def trans1(txn): + yield Transaction(name="TEST1").save() + txn.rollback() # should propagate to the root transaction + + @transaction + @inlineCallbacks + def trans2(txn): + yield Transaction(name="TEST2").save() + yield trans1() + + barrier.wait() # wait here to delay commit + + d = trans2() + + count = yield Transaction.count() + self.assertEqual(count, 0) + + barrier.set() + + yield d + + count = yield Transaction.count() + self.assertEqual(count, 0) + + @inlineCallbacks + def test_fake_nesting_ctxmgr(self): + @transaction + @inlineCallbacks + def trans1(txn): + yield Transaction(name="TEST1").save() + with transaction() as txn2: + yield Transaction(name="TEST2").save() + txn2.rollback() + + yield trans1() + + count = yield Transaction.count() + self.assertEqual(count, 0) + + @inlineCallbacks + def test_parallel_transactions(self): + if DBTYPE == "sqlite": + raise unittest.SkipTest("Parallel connections are not supported by sqlite") + + threadIds = [] + + # trans1 is supposed to pass, trans2 is supposed to fail due to unique constraint + # regarding synchronization: trans1 has to start INSERT before trans2, + # because otherwise it would wait for trans2 to finish due to postgres synchronization strategy + + on_trans1_insert = Event() + barrier1, barrier2 = Event(), Event() + + @transaction + @inlineCallbacks + def trans1(txn): + threadIds.append(threadable.getThreadID()) + yield Transaction(name="TEST1").save() + on_trans1_insert.set() + barrier1.wait() # wait here to delay commit) + + @transaction + @inlineCallbacks + def trans2(txn): + threadIds.append(threadable.getThreadID()) + on_trans1_insert.wait() + yield Transaction(name="TEST1").save() + barrier2.wait() # wait here to delay commit + + d1 = trans1() + d2 = trans2() + + # commit tran1, should pass: + barrier1.set() + yield d1 + + count = yield Transaction.count() + self.assertEqual(count, 1) + + # commit trans2: + barrier2.set() + + # should fail due to unique constraint violation + yield self._assertRaises(d2, Exception) + + self.assertNotEqual(threadIds[0], threadIds[1], "Parallel transactions don't run in different threads") + + count = yield Transaction.count() + self.assertEqual(count, 1) + + @inlineCallbacks + def test_sanity_checks(self): + # Already rollbacked/commited: + @transaction + def trans1(txn): + txn.rollback() + txn.commit() + + yield self._assertRaises(trans1(), TransactionError) + + # With nesting: + @transaction + def trans2(txn): + with transaction() as txn2: + txn2.rollback() + txn.commit() + + yield self._assertRaises(trans2(), TransactionError) diff --git a/twistar/transaction.py b/twistar/transaction.py new file mode 100644 index 0000000..fa8c7a3 --- /dev/null +++ b/twistar/transaction.py @@ -0,0 +1,161 @@ +import threading +import functools + +from twisted.enterprise import adbapi +from twisted.internet.defer import inlineCallbacks, maybeDeferred, returnValue, Deferred +from twisted.python import threadable + +from twistar.registry import Registry +from twistar.exceptions import TransactionError + + +class TransactionGuard(threading.local): + + def __init__(self): + self._txn = None + + @property + def txn(self): + return self._txn + + @txn.setter + def txn(self, txn): + self._txn = txn + + +class _Transaction(object): + """Mostly borrowed from sqlalchemy and adapted to adbapi""" + + def __init__(self, parent): + self._actual_parent = parent + self.is_active = True + + if not self._parent.is_active: + raise TransactionError("Parent transaction is inactive") + + Registry.getConfig().txnGuard.txn = self + + @property + def _parent(self): + return self._actual_parent or self + + def rollback(self): + if not self._parent.is_active: + return + + Registry.getConfig().txnGuard.txn = self._actual_parent + self._do_rollback() + self.is_active = False + + def _do_rollback(self): + self._parent.rollback() + + def commit(self): + if not self._parent.is_active: + raise TransactionError("This transaction is inactive") + + Registry.getConfig().txnGuard.txn = self._actual_parent + self._do_commit() + self.is_active = False + + def _do_commit(self): + pass + + def __enter__(self): + return self + + def __exit__(self, excType, exc, traceback): + if excType is not None and issubclass(excType, Exception): + self.rollback() + elif self.is_active: + try: + self.commit() + except: + self.rollback() + raise + + def __getattr__(self, key): + return getattr(self._parent, key) + + +class _RootTransaction(adbapi.Transaction, _Transaction): + + def __init__(self, pool, connection): + adbapi.Transaction.__init__(self, pool, connection) + _Transaction.__init__(self, None) + + def close(self): + # don't set to None but errorout on subsequent access + self._cursor.close() + + def _do_rollback(self): + if self.is_active: + self._connection.rollback() + self.close() + + def _do_commit(self): + if self.is_active: + self._connection.commit() + self.close() + + def __getattr__(self, key): + return getattr(self._cursor, key) + + +class _SavepointTransaction(object): + pass + + +def _transaction_dec(func, create_transaction): + @inlineCallbacks + def _runTransaction(*args, **kwargs): + with create_transaction() as txn: + res = yield maybeDeferred(func, txn, *args, **kwargs) + returnValue(res) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + d = None # declare here so that on_result can acces it + + def on_result(success, result): + from twisted.internet import reactor + + if not success: + reactor.callFromThread(d.errback, result) + elif isinstance(result, Deferred): + result.addCallbacks(lambda res: reactor.callFromThread(d.callback, res), + lambda res: reactor.callFromThread(d.errback, res)) + else: + reactor.callFromThread(d.callback, result) + + if threadable.isInIOThread(): + d = Deferred() + thpool = Registry.DBPOOL.threadpool + thpool.callInThreadWithCallback(on_result, _runTransaction, *args, **kwargs) + return d + else: + # we are already in a db thread, so just execute the transaction + return _runTransaction(*args, **kwargs) + + return wrapper + + +def transaction(func=None): + if func is None: + conn_pool = Registry.DBPOOL + cfg = Registry.getConfig() + + if cfg.txnGuard.txn is None: + conn = conn_pool.connectionFactory(conn_pool) + return _RootTransaction(conn_pool, conn) + else: + return _Transaction(cfg.txnGuard.txn) + else: + return _transaction_dec(func, transaction) + + +def nested_transaction(func=None): + if func is None: + pass + else: + _transaction_dec(func, nested_transaction) diff --git a/twistar/utils.py b/twistar/utils.py index c2efbdb..dceb4b8 100644 --- a/twistar/utils.py +++ b/twistar/utils.py @@ -8,30 +8,6 @@ from twistar.exceptions import TransactionError -def transaction(interaction): - """ - A decorator to wrap any code in a transaction. If any exceptions are raised, all modifications - are rolled back. The function that is decorated should accept at least one argument, which is - the transaction (in case you want to operate directly on it). - """ - def _transaction(txn, args, kwargs): - config = Registry.getConfig() - config.txn = txn - # get the result of the functions *synchronously*, since this is in a transaction - try: - result = threads.blockingCallFromThread(reactor, interaction, txn, *args, **kwargs) - config.txn = None - return result - except Exception, e: - config.txn = None - raise TransactionError(str(e)) - - def wrapper(*args, **kwargs): - return Registry.DBPOOL.runInteraction(_transaction, args, kwargs) - - return wrapper - - def createInstances(props, klass): """ Create an instance of C{list} of instances of a given class From 2a46f1151fc42b449dd5bec51c8ee0831d22f4aa Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Tue, 18 Oct 2016 19:56:13 +0200 Subject: [PATCH 2/7] Fixed import --- twistar/dbobject.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/twistar/dbobject.py b/twistar/dbobject.py index cdfc7f2..1dd96fc 100644 --- a/twistar/dbobject.py +++ b/twistar/dbobject.py @@ -6,7 +6,8 @@ from twistar.registry import Registry from twistar.relationships import Relationship from twistar.exceptions import InvalidRelationshipError, DBObjectSaveError, ReferenceNotSavedError -from twistar.utils import createInstances, deferredDict, dictToWhere, transaction +from twistar.utils import createInstances, deferredDict, dictToWhere +from twistar.transaction import transaction from twistar.validation import Validator, Errors from BermiInflector.Inflector import Inflector From 861959924e1e4911888357326817976066bcb47f Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Tue, 18 Oct 2016 21:21:56 +0200 Subject: [PATCH 3/7] More sanity checks --- tox.ini | 2 ++ twistar/tests/test_transactions.py | 40 ++++++++++++++++++++-- twistar/transaction.py | 54 ++++++++++++++++++++++-------- 3 files changed, 80 insertions(+), 16 deletions(-) diff --git a/tox.ini b/tox.ini index 5b0b8b8..707b270 100644 --- a/tox.ini +++ b/tox.ini @@ -18,6 +18,8 @@ envlist = pypy3-twisted15 [testenv] +passenv = DBTYPE +sitepackages = True deps = coverage twisted>=15.0, <16.0 diff --git a/twistar/tests/test_transactions.py b/twistar/tests/test_transactions.py index 153c804..b51b90e 100644 --- a/twistar/tests/test_transactions.py +++ b/twistar/tests/test_transactions.py @@ -3,8 +3,7 @@ from twisted.trial import unittest from twisted.internet import reactor -from twisted.internet.error import AlreadyCalled -from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, maybeDeferred from twisted.python import threadable from twistar.transaction import transaction @@ -251,3 +250,40 @@ def trans2(txn): txn.commit() yield self._assertRaises(trans2(), TransactionError) + + # Error if started in main thread: + yield self._assertRaises(maybeDeferred(transaction), TransactionError) + + # Error if rollbacked/commited in another thread: + main_thread_d = Deferred() + on_cb_added = Event() + on_callbacked = Event() + + @transaction + def trans3(txn): + def from_mainthread(do_commit): + if do_commit: + txn.commit() + else: + txn.rollback() + + main_thread_d.addCallback(from_mainthread) + on_cb_added.set() + on_callbacked.wait() # don't return (which would cause commit) until main thread executed callbacks + return main_thread_d # deferred will fail if from_mainthread() raised an Exception + + d = trans3() + on_cb_added.wait() # we need to wait for the callback to be registered otherwise it would be executed in db thread + main_thread_d.callback(True) # will commit the transaction in main thread + on_callbacked.set() + yield self._assertRaises(d, TransactionError) + + main_thread_d = Deferred() + on_cb_added.clear() + on_callbacked.clear() + + d = trans3() + on_cb_added.wait() + main_thread_d.callback(False) # will rollback the transaction in main thread + on_callbacked.set() + yield self._assertRaises(d, TransactionError) diff --git a/twistar/transaction.py b/twistar/transaction.py index fa8c7a3..627a309 100644 --- a/twistar/transaction.py +++ b/twistar/transaction.py @@ -27,8 +27,13 @@ class _Transaction(object): """Mostly borrowed from sqlalchemy and adapted to adbapi""" def __init__(self, parent): + # Transactions must not be started in the main thread + if threading.current_thread() not in Registry.DBPOOL.threadpool.threads: + raise TransactionError("Transaction must only be started in a db pool thread") + self._actual_parent = parent self.is_active = True + self._threadId = threadable.getThreadID() if not self._parent.is_active: raise TransactionError("Parent transaction is inactive") @@ -39,7 +44,15 @@ def __init__(self, parent): def _parent(self): return self._actual_parent or self + def _assertCorrectThread(self): + if threadable.getThreadID() != self._threadId: + raise TransactionError("Tried to rollback a transaction from a different thread.\n" + "Make sure that you properly use blockingCallFromThread() and\n" + "that you don't add callbacks to Deferreds which get resolved from another thread.") + def rollback(self): + self._assertCorrectThread() + if not self._parent.is_active: return @@ -51,6 +64,8 @@ def _do_rollback(self): self._parent.rollback() def commit(self): + self._assertCorrectThread() + if not self._parent.is_active: raise TransactionError("This transaction is inactive") @@ -107,26 +122,37 @@ class _SavepointTransaction(object): def _transaction_dec(func, create_transaction): - @inlineCallbacks + def _runTransaction(*args, **kwargs): - with create_transaction() as txn: - res = yield maybeDeferred(func, txn, *args, **kwargs) - returnValue(res) + txn = create_transaction() + + def on_succcess(result): + if txn.is_active: + try: + txn.commit() + except: + txn.rollback() + return result + + def on_error(fail): + if txn.is_active: + txn.rollback() + + return fail + + d = maybeDeferred(func, txn, *args, **kwargs) + d.addCallbacks(on_succcess, on_error) + d.addErrback(on_error) + return d @functools.wraps(func) def wrapper(*args, **kwargs): - d = None # declare here so that on_result can acces it + d = None # declare here so that on_result can access it - def on_result(success, result): + def on_result(success, txn_deferred): from twisted.internet import reactor - - if not success: - reactor.callFromThread(d.errback, result) - elif isinstance(result, Deferred): - result.addCallbacks(lambda res: reactor.callFromThread(d.callback, res), - lambda res: reactor.callFromThread(d.errback, res)) - else: - reactor.callFromThread(d.callback, result) + txn_deferred.addCallbacks(lambda res: reactor.callFromThread(d.callback, res), + lambda fail: reactor.callFromThread(d.errback, fail)) if threadable.isInIOThread(): d = Deferred() From 6025364f28a0de90488d705c961cba95f0826f0d Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Tue, 18 Oct 2016 22:19:18 +0200 Subject: [PATCH 4/7] Added thread_check option --- twistar/tests/test_transactions.py | 3 +++ twistar/transaction.py | 24 ++++++++++++------------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/twistar/tests/test_transactions.py b/twistar/tests/test_transactions.py index b51b90e..0e37b0d 100644 --- a/twistar/tests/test_transactions.py +++ b/twistar/tests/test_transactions.py @@ -254,6 +254,9 @@ def trans2(txn): # Error if started in main thread: yield self._assertRaises(maybeDeferred(transaction), TransactionError) + # But shouldn't fail if called with thread_check=False + transaction(thread_check=False).rollback() + # Error if rollbacked/commited in another thread: main_thread_d = Deferred() on_cb_added = Event() diff --git a/twistar/transaction.py b/twistar/transaction.py index 627a309..e84e01d 100644 --- a/twistar/transaction.py +++ b/twistar/transaction.py @@ -2,7 +2,7 @@ import functools from twisted.enterprise import adbapi -from twisted.internet.defer import inlineCallbacks, maybeDeferred, returnValue, Deferred +from twisted.internet.defer import maybeDeferred, Deferred from twisted.python import threadable from twistar.registry import Registry @@ -26,9 +26,9 @@ def txn(self, txn): class _Transaction(object): """Mostly borrowed from sqlalchemy and adapted to adbapi""" - def __init__(self, parent): - # Transactions must not be started in the main thread - if threading.current_thread() not in Registry.DBPOOL.threadpool.threads: + def __init__(self, parent, thread_check=True): + # Transactions must be started in db thread unless explicitely permitted + if thread_check and threading.current_thread() not in Registry.DBPOOL.threadpool.threads: raise TransactionError("Transaction must only be started in a db pool thread") self._actual_parent = parent @@ -95,9 +95,9 @@ def __getattr__(self, key): class _RootTransaction(adbapi.Transaction, _Transaction): - def __init__(self, pool, connection): + def __init__(self, pool, connection, thread_check=True): adbapi.Transaction.__init__(self, pool, connection) - _Transaction.__init__(self, None) + _Transaction.__init__(self, None, thread_check=thread_check) def close(self): # don't set to None but errorout on subsequent access @@ -166,22 +166,22 @@ def on_result(success, txn_deferred): return wrapper -def transaction(func=None): +def transaction(func=None, thread_check=True): if func is None: conn_pool = Registry.DBPOOL cfg = Registry.getConfig() if cfg.txnGuard.txn is None: conn = conn_pool.connectionFactory(conn_pool) - return _RootTransaction(conn_pool, conn) + return _RootTransaction(conn_pool, conn, thread_check=thread_check) else: - return _Transaction(cfg.txnGuard.txn) + return _Transaction(cfg.txnGuard.txn, thread_check=thread_check) else: - return _transaction_dec(func, transaction) + return _transaction_dec(func, functools.partial(transaction, thread_check=thread_check)) -def nested_transaction(func=None): +def nested_transaction(func=None, thread_check=True): if func is None: pass else: - _transaction_dec(func, nested_transaction) + _transaction_dec(func, functools.partial(nested_transaction, thread_check=thread_check)) From 760b4c1f3e40018e097721d0328e98433fc88a55 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Wed, 19 Oct 2016 17:31:03 +0200 Subject: [PATCH 5/7] SAVEPOINT transactions --- twistar/tests/test_transactions.py | 54 +++++++++++++++++++++++++++++- twistar/transaction.py | 39 +++++++++++++++------ 2 files changed, 82 insertions(+), 11 deletions(-) diff --git a/twistar/tests/test_transactions.py b/twistar/tests/test_transactions.py index 0e37b0d..de286bc 100644 --- a/twistar/tests/test_transactions.py +++ b/twistar/tests/test_transactions.py @@ -6,7 +6,7 @@ from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, maybeDeferred from twisted.python import threadable -from twistar.transaction import transaction +from twistar.transaction import transaction, nested_transaction from twistar.exceptions import TransactionError from twistar.tests.utils import initDB, tearDownDB, Registry, Transaction, DBTYPE @@ -232,6 +232,58 @@ def trans2(txn): count = yield Transaction.count() self.assertEqual(count, 1) + @inlineCallbacks + def test_savepoints_commit(self): + @transaction + @inlineCallbacks + def trans1(txn): + yield Transaction(name="TEST1").save() + with nested_transaction(): + yield Transaction(name="TEST2").save() + yield Transaction(name="TEST3").save() + + yield trans1() + objects = yield Transaction.all() + self.assertEqual([obj.name for obj in objects], ["TEST1", "TEST2", "TEST3"]) + + @inlineCallbacks + def test_savepoints_rollback(self): + @transaction + @inlineCallbacks + def trans1(txn): + yield Transaction(name="TEST1").save() + with nested_transaction() as txn2: + yield Transaction(name="TEST2").save() + txn2.rollback() + yield Transaction(name="TEST3").save() + + yield trans1() + objects = yield Transaction.all() + self.assertEqual([obj.name for obj in objects], ["TEST1", "TEST3"]) + + @inlineCallbacks + def test_savepoints_mixed(self): + @nested_transaction + @inlineCallbacks + def trans1(txn): + yield Transaction(name="TEST3").save() + with transaction() as txn2: + yield Transaction(name="TEST4").save() + txn2.rollback() + + @transaction + @inlineCallbacks + def trans2(txn): + yield Transaction(name="TEST1").save() + with nested_transaction(): + yield Transaction(name="TEST2").save() + yield trans1() + yield Transaction(name="TEST5").save() + + yield trans2() + objects = yield Transaction.all() + self.assertEqual([obj.name for obj in objects], ["TEST1", "TEST2", "TEST5"]) + @inlineCallbacks def test_sanity_checks(self): # Already rollbacked/commited: diff --git a/twistar/transaction.py b/twistar/transaction.py index e84e01d..03d1e13 100644 --- a/twistar/transaction.py +++ b/twistar/transaction.py @@ -31,9 +31,15 @@ def __init__(self, parent, thread_check=True): if thread_check and threading.current_thread() not in Registry.DBPOOL.threadpool.threads: raise TransactionError("Transaction must only be started in a db pool thread") + if parent is None: + self._root = self + else: + self._root = parent._root + self._actual_parent = parent self.is_active = True self._threadId = threadable.getThreadID() + self._savepoint_seq = 0 if not self._parent.is_active: raise TransactionError("Parent transaction is inactive") @@ -90,7 +96,7 @@ def __exit__(self, excType, exc, traceback): raise def __getattr__(self, key): - return getattr(self._parent, key) + return getattr(self._root, key) class _RootTransaction(adbapi.Transaction, _Transaction): @@ -117,8 +123,23 @@ def __getattr__(self, key): return getattr(self._cursor, key) -class _SavepointTransaction(object): - pass +class _SavepointTransaction(_Transaction): + + def __init__(self, parent, thread_check=True): + super(_SavepointTransaction, self).__init__(parent, thread_check=thread_check) + + self._root._savepoint_seq += 1 + self._name = "twistar_savepoint_{}".format(self._root._savepoint_seq) + + self.execute("SAVEPOINT {}".format(self._name)) + + def _do_rollback(self): + if self.is_active: + self.execute("ROLLBACK TO SAVEPOINT {}".format(self._name)) + + def _do_commit(self): + if self.is_active: + self.execute("RELEASE SAVEPOINT {}".format(self._name)) def _transaction_dec(func, create_transaction): @@ -166,7 +187,7 @@ def on_result(success, txn_deferred): return wrapper -def transaction(func=None, thread_check=True): +def transaction(func=None, nested=False, thread_check=True): if func is None: conn_pool = Registry.DBPOOL cfg = Registry.getConfig() @@ -174,14 +195,12 @@ def transaction(func=None, thread_check=True): if cfg.txnGuard.txn is None: conn = conn_pool.connectionFactory(conn_pool) return _RootTransaction(conn_pool, conn, thread_check=thread_check) + elif nested: + return _SavepointTransaction(cfg.txnGuard.txn, thread_check=thread_check) else: return _Transaction(cfg.txnGuard.txn, thread_check=thread_check) else: - return _transaction_dec(func, functools.partial(transaction, thread_check=thread_check)) + return _transaction_dec(func, functools.partial(transaction, nested=nested, thread_check=thread_check)) -def nested_transaction(func=None, thread_check=True): - if func is None: - pass - else: - _transaction_dec(func, functools.partial(nested_transaction, thread_check=thread_check)) +nested_transaction = functools.partial(transaction, nested=True) From f6c1301196cf693024713a665f728d251ac7093d Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Wed, 19 Oct 2016 18:32:55 +0200 Subject: [PATCH 6/7] Documentation & check for sqlite --- twistar/tests/test_transactions.py | 9 +++++++++ twistar/transaction.py | 32 ++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/twistar/tests/test_transactions.py b/twistar/tests/test_transactions.py index de286bc..9f30d83 100644 --- a/twistar/tests/test_transactions.py +++ b/twistar/tests/test_transactions.py @@ -234,6 +234,9 @@ def trans2(txn): @inlineCallbacks def test_savepoints_commit(self): + if DBTYPE == "sqlite": + raise unittest.SkipTest("SAVEPOINT acts weird with sqlite, needs further inspection.") + @transaction @inlineCallbacks def trans1(txn): @@ -248,6 +251,9 @@ def trans1(txn): @inlineCallbacks def test_savepoints_rollback(self): + if DBTYPE == "sqlite": + raise unittest.SkipTest("SAVEPOINT acts weird with sqlite, needs further inspection.") + @transaction @inlineCallbacks def trans1(txn): @@ -263,6 +269,9 @@ def trans1(txn): @inlineCallbacks def test_savepoints_mixed(self): + if DBTYPE == "sqlite": + raise unittest.SkipTest("SAVEPOINT acts weird with sqlite, needs further inspection.") + @nested_transaction @inlineCallbacks def trans1(txn): diff --git a/twistar/transaction.py b/twistar/transaction.py index 03d1e13..e9e3cd7 100644 --- a/twistar/transaction.py +++ b/twistar/transaction.py @@ -188,6 +188,38 @@ def on_result(success, txn_deferred): def transaction(func=None, nested=False, thread_check=True): + """Starts a new transaction. + + A Transaction object returned by this function can be used as a context manager, + which will atomatically be commited or rolledback if an exception is raised. + + Transactions must only be used in db threads. This behaviour can be overriden by setting the + 'thread_check' to False, allowing transactions to be started in arbitrary threads which is + useful to e.g simplify testcases. + + If this function is used as decorator, the decorated function will be executed in a db thread and + gets the Transaction passed as first argument. Decorated functions are allowed to return Deferreds. + E.g: + @transaction + def someFunc(txn, param1): + # Runs in a db thread + + d = someFunc(1) # will be calledback (in mainthread) when someFunc returns + + You have to make sure, that you use blockingCallFromThread() or use synchronization if you need to + interact with code which runs in the mainthread. Also care has to be taken when waiting for Deferreds. + You must assure that the callbacks will be invoked from the db thread. + + Per default transactions can be nested: Commiting such a "nested" transaction will simply do nothing, + but a rollback on it will rollback the outermost transaction. This allow creation of functions which will + either create a new transaction or will participate in an already ongoing tranaction which is handy for library code. + + SAVEPOINT transactions can be used by either setting the 'nested' flag to true or by calling the 'nested_transaction' function. + """ + if nested and Registry.DBPOOL.dbapi.__name__ == "sqlite3": + # nees some modification on our side, see: http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + raise NotImplementedError("sqlite currently not supported") + if func is None: conn_pool = Registry.DBPOOL cfg = Registry.getConfig() From 8310e31dcca336768d816ac2e491cb93f56ce6d9 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Thu, 20 Oct 2016 02:17:18 +0200 Subject: [PATCH 7/7] More test, fixed linting --- twistar/tests/test_transactions.py | 31 +++++++++++++++++++++++++++++- twistar/transaction.py | 3 ++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/twistar/tests/test_transactions.py b/twistar/tests/test_transactions.py index 9f30d83..a0c329e 100644 --- a/twistar/tests/test_transactions.py +++ b/twistar/tests/test_transactions.py @@ -3,7 +3,7 @@ from twisted.trial import unittest from twisted.internet import reactor -from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, maybeDeferred +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, maybeDeferred, DeferredList from twisted.python import threadable from twistar.transaction import transaction, nested_transaction @@ -232,6 +232,35 @@ def trans2(txn): count = yield Transaction.count() self.assertEqual(count, 1) + @inlineCallbacks + def test_parallel_massive(self): + # Make sure that everything works alright even when starting a massive amount of parallel transactions + if DBTYPE == "sqlite": + raise unittest.SkipTest("Parallel connections are not supported by sqlite") + + N = 100 + + @transaction + @inlineCallbacks + def trans(txn, i): + yield Transaction(name=str(i)).save() + if i % 2 == 1: + txn.rollback() + else: + txn.commit() + + deferreds = [trans(i) for i in range(N)] + + results = yield DeferredList(deferreds) + self.assertTrue(all(success for success, result in results)) + + objects = yield Transaction.all() + actual = sorted(int(obj.name) for obj in objects) + actual = [str(i) for i in actual] + expected = [str(i) for i in range(0, N, 2)] + + self.assertEquals(actual, expected) + @inlineCallbacks def test_savepoints_commit(self): if DBTYPE == "sqlite": diff --git a/twistar/transaction.py b/twistar/transaction.py index e9e3cd7..67397c0 100644 --- a/twistar/transaction.py +++ b/twistar/transaction.py @@ -217,7 +217,8 @@ def someFunc(txn, param1): SAVEPOINT transactions can be used by either setting the 'nested' flag to true or by calling the 'nested_transaction' function. """ if nested and Registry.DBPOOL.dbapi.__name__ == "sqlite3": - # nees some modification on our side, see: http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + # needs some modification on our side, see: + # http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl raise NotImplementedError("sqlite currently not supported") if func is None: