@@ -8,11 +8,17 @@ import org.jetbrains.exposed.sql.addLogger
8
8
import org.jetbrains.exposed.sql.exposedLogger
9
9
import org.jetbrains.exposed.sql.transactions.TransactionManager
10
10
import org.jetbrains.exposed.sql.transactions.transactionManager
11
+ import org.springframework.jdbc.datasource.ConnectionHandle
12
+ import org.springframework.jdbc.datasource.ConnectionHolder
13
+ import org.springframework.jdbc.datasource.JdbcTransactionObjectSupport
14
+ import org.springframework.transaction.CannotCreateTransactionException
11
15
import org.springframework.transaction.TransactionDefinition
12
16
import org.springframework.transaction.TransactionSystemException
13
17
import org.springframework.transaction.support.AbstractPlatformTransactionManager
14
18
import org.springframework.transaction.support.DefaultTransactionStatus
15
- import org.springframework.transaction.support.SmartTransactionObject
19
+ import org.springframework.transaction.support.TransactionSynchronizationManager
20
+ import org.springframework.transaction.support.TransactionSynchronizationUtils
21
+ import java.sql.Connection
16
22
import javax.sql.DataSource
17
23
18
24
/* *
@@ -25,13 +31,12 @@ import javax.sql.DataSource
25
31
* @sample org.jetbrains.exposed.spring.TestConfig
26
32
*/
27
33
class SpringTransactionManager (
28
- dataSource : DataSource ,
34
+ private val dataSource : DataSource ,
29
35
databaseConfig : DatabaseConfig = DatabaseConfig {},
30
36
private val showSql : Boolean = false ,
31
37
) : AbstractPlatformTransactionManager() {
32
38
33
39
private var _database : Database
34
-
35
40
private var _transactionManager : TransactionManager
36
41
37
42
private val threadLocalTransactionManager: TransactionManager
@@ -63,16 +68,23 @@ class SpringTransactionManager(
63
68
manager = threadLocalTransactionManager,
64
69
outerManager = outerManager,
65
70
outerTransaction = outer,
66
- )
71
+ ).apply {
72
+ setConnectionHolder(
73
+ TransactionSynchronizationManager .getResource(dataSource) as ? ConnectionHolder
74
+ )
75
+ }
67
76
}
68
77
69
78
override fun doSuspend (transaction : Any ): Any {
70
79
val trxObject = transaction as ExposedTransactionObject
71
80
val currentManager = trxObject.manager
72
81
82
+ trxObject.setConnectionHolder(null )
83
+
73
84
return SuspendedObject (
74
85
transaction = currentManager.currentOrNull() as Transaction ,
75
86
manager = currentManager,
87
+ connectionHolder = TransactionSynchronizationManager .unbindResource(dataSource) as ConnectionHolder ,
76
88
).apply {
77
89
currentManager.bindTransactionToThread(null )
78
90
TransactionManager .resetCurrent(null )
@@ -84,11 +96,13 @@ class SpringTransactionManager(
84
96
85
97
TransactionManager .resetCurrent(suspendedObject.manager)
86
98
threadLocalTransactionManager.bindTransactionToThread(suspendedObject.transaction)
99
+ TransactionSynchronizationManager .bindResource(dataSource, suspendedObject.connectionHolder)
87
100
}
88
101
89
102
private data class SuspendedObject (
90
103
val transaction : Transaction ,
91
- val manager : TransactionManager
104
+ val manager : TransactionManager ,
105
+ val connectionHolder : ConnectionHolder ,
92
106
)
93
107
94
108
override fun isExistingTransaction (transaction : Any ): Boolean {
@@ -102,7 +116,7 @@ class SpringTransactionManager(
102
116
val currentTransactionManager = trxObject.manager
103
117
TransactionManager .resetCurrent(threadLocalTransactionManager)
104
118
105
- currentTransactionManager.newTransaction(
119
+ val transaction = currentTransactionManager.newTransaction(
106
120
isolation = definition.isolationLevel,
107
121
readOnly = definition.isReadOnly,
108
122
outerTransaction = currentTransactionManager.currentOrNull()
@@ -115,6 +129,24 @@ class SpringTransactionManager(
115
129
addLogger(StdOutSqlLogger )
116
130
}
117
131
}
132
+
133
+ @Suppress(" TooGenericExceptionCaught" )
134
+ try {
135
+ if (! trxObject.hasConnectionHolder()) {
136
+ trxObject.connectionHolder = ConnectionHolder (ExposedConnectionHandle (transaction))
137
+ trxObject.isNewConnectionHolder = true
138
+ }
139
+
140
+ trxObject.getConnectionHolder().isSynchronizedWithTransaction = true
141
+
142
+ // Bind the connection holder to the thread.
143
+ if (trxObject.isNewConnectionHolder) {
144
+ TransactionSynchronizationManager .bindResource(dataSource, trxObject.getConnectionHolder())
145
+ }
146
+ } catch (ex: Throwable ) {
147
+ trxObject.setConnectionHolder(null )
148
+ throw CannotCreateTransactionException (" Could not open JDBC Connection for transaction" , ex)
149
+ }
118
150
}
119
151
120
152
override fun doCommit (status : DefaultTransactionStatus ) {
@@ -135,8 +167,13 @@ class SpringTransactionManager(
135
167
trxObject.cleanUpTransactionIfIsPossible {
136
168
closeStatementsAndConnections(it)
137
169
}
138
-
139
170
trxObject.setCurrentToOuter()
171
+
172
+ if (trxObject.isNewConnectionHolder) {
173
+ TransactionSynchronizationManager .unbindResource(dataSource)
174
+ trxObject.getConnectionHolder().released()
175
+ }
176
+ trxObject.getConnectionHolder().clear()
140
177
}
141
178
142
179
private fun closeStatementsAndConnections (transaction : Transaction ) {
@@ -169,9 +206,17 @@ class SpringTransactionManager(
169
206
val manager : TransactionManager ,
170
207
val outerManager : TransactionManager ,
171
208
private val outerTransaction : Transaction ? ,
172
- ) : SmartTransactionObject {
209
+ ) : JdbcTransactionObjectSupport() {
173
210
174
211
private var isRollback: Boolean = false
212
+ var isNewConnectionHolder: Boolean = false
213
+
214
+ // the Java base class has asymmetric nullability for its connectionHolder getters
215
+ // and setters - which confuses the Kotlin compiler and makes it produce warnings/suggestions
216
+ // regardless of which style you choose. To avoid it we override this.
217
+ override fun setConnectionHolder (connectionHolder : ConnectionHolder ? ) {
218
+ super .setConnectionHolder(connectionHolder)
219
+ }
175
220
176
221
fun cleanUpTransactionIfIsPossible (block : (transaction: Transaction ) -> Unit ) {
177
222
val currentTransaction = getCurrentTransaction()
@@ -212,7 +257,15 @@ class SpringTransactionManager(
212
257
override fun isRollbackOnly () = isRollback
213
258
214
259
override fun flush () {
215
- // Do noting
260
+ TransactionSynchronizationUtils .triggerFlush()
261
+ }
262
+ }
263
+
264
+ class ExposedConnectionHandle (
265
+ val transaction : Transaction
266
+ ) : ConnectionHandle {
267
+ override fun getConnection (): Connection {
268
+ return transaction.connection.connection as Connection
216
269
}
217
270
}
218
271
}
0 commit comments