Skip to content

Commit b9426c8

Browse files
jclynesvc-squareup-copybara
authored andcommitted
Adds additional operations for working with MDC in suspending
and non-suspending contexts GitOrigin-RevId: 2f1a75b27aa07cf9508a61ae9e10f60370aa8e3d
1 parent 57778b6 commit b9426c8

File tree

12 files changed

+173
-105
lines changed

12 files changed

+173
-105
lines changed

gradle/libs.versions.toml

+1
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ kotlinStdLibJdk8 = { module = "org.jetbrains.kotlin:kotlin-stdlib-jdk8", version
151151
kotlinTest = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" }
152152
kotlinxHtml = { module = "org.jetbrains.kotlinx:kotlinx-html-jvm", version = "0.12.0" }
153153
kotlinxCoroutinesCore = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version = "1.10.1" }
154+
kotlinxCoroutinesSlf4j = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-slf4j", version = "1.10.1" }
154155
kotlinxCoroutinesTest = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version = "1.10.1" }
155156
kubernetesClient = { module = "io.kubernetes:client-java", version = "18.0.1" }
156157
kubernetesClientApi = { module = "io.kubernetes:client-java-api", version = "18.0.1" }

misk-api/api/misk-api.api

+2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ public final class misk/client/NetworkInterceptorWrapper : okhttp3/Interceptor {
5050
public abstract interface class misk/logging/Mdc {
5151
public abstract fun clear ()V
5252
public abstract fun get (Ljava/lang/String;)Ljava/lang/String;
53+
public abstract fun getCopyOfContextMap ()Ljava/util/Map;
5354
public abstract fun put (Ljava/lang/String;Ljava/lang/String;)V
55+
public abstract fun setContextMap (Ljava/util/Map;)V
5456
}
5557

5658
public abstract interface class misk/scope/ActionScoped {
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
package misk.logging
22

3+
typealias MdcContextMap = Map<String, String>
4+
35
interface Mdc {
46
fun put(key: String, value: String?)
57

68
fun get(key: String): String?
79

810
fun clear()
11+
12+
fun setContextMap(context: MdcContextMap)
13+
14+
fun getCopyOfContextMap(): MdcContextMap?
15+
916
}
1017

misk/api/misk.api

+12
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,19 @@ public final class misk/logging/MiskMdc : misk/logging/Mdc {
835835
public static final field INSTANCE Lmisk/logging/MiskMdc;
836836
public fun clear ()V
837837
public fun get (Ljava/lang/String;)Ljava/lang/String;
838+
public fun getCopyOfContextMap ()Ljava/util/Map;
838839
public fun put (Ljava/lang/String;Ljava/lang/String;)V
840+
public fun setContextMap (Ljava/util/Map;)V
841+
}
842+
843+
public final class misk/logging/ScopedMdcKt {
844+
public static final fun withMdc (Lmisk/logging/Mdc;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V
845+
public static final fun withMdc (Lmisk/logging/Mdc;[Lkotlin/Pair;Lkotlin/jvm/functions/Function0;)V
846+
}
847+
848+
public final class misk/logging/coroutines/ScopedMdcKt {
849+
public static final fun withMdc (Lmisk/logging/Mdc;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
850+
public static final fun withMdc (Lmisk/logging/Mdc;[Lkotlin/Pair;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
839851
}
840852

841853
public final class misk/monitoring/JvmMetrics {

misk/build.gradle.kts

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ dependencies {
5454
implementation(libs.kotlinReflect)
5555
implementation(libs.kotlinStdLibJdk8)
5656
implementation(libs.kotlinxCoroutinesCore)
57+
implementation(libs.kotlinxCoroutinesSlf4j)
5758
implementation(libs.moshiAdapters)
5859
implementation(libs.okio)
5960
implementation(libs.openTracingConcurrent)
@@ -81,6 +82,7 @@ dependencies {
8182
testImplementation(libs.junitParams)
8283
testImplementation(libs.kotestAssertions)
8384
testImplementation(libs.kotlinTest)
85+
testImplementation(libs.kotlinxCoroutinesCore)
8486
testImplementation(libs.kotlinxCoroutinesTest)
8587
testImplementation(libs.logbackClassic)
8688
testImplementation(libs.okHttpMockWebServer)

misk/src/main/kotlin/misk/logging/DynamicMdcContext.kt

-55
This file was deleted.

misk/src/main/kotlin/misk/logging/MiskMdc.kt

+10
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,14 @@ object MiskMdc : Mdc {
1313
override fun clear() {
1414
MDC.clear()
1515
}
16+
17+
override fun setContextMap(context: MdcContextMap) {
18+
if (context.isNotEmpty()) {
19+
MDC.setContextMap(context)
20+
} else {
21+
MDC.clear()
22+
}
23+
}
24+
25+
override fun getCopyOfContextMap(): MdcContextMap? = MDC.getCopyOfContextMap()
1626
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package misk.logging
2+
3+
import org.slf4j.MDC
4+
5+
/**
6+
* Adds the given key, value pair to the MDC for the duration of the block.
7+
*/
8+
inline fun Mdc.withMdc(key: String, value: String, block: () -> Unit) =
9+
withMdc(key to value, block = block)
10+
11+
/**
12+
* Adds the given tags to the MDC for the duration of the block.
13+
*/
14+
inline fun Mdc.withMdc(vararg tags: Pair<String, String>, block: () -> Unit) {
15+
val oldState = getCopyOfContextMap()
16+
return try {
17+
tags.forEach { (key, value) -> put(key, value) }
18+
block()
19+
} finally {
20+
oldState?.let { setContextMap(it) } ?: clear()
21+
}
22+
}
23+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package misk.logging.coroutines
2+
3+
import kotlinx.coroutines.slf4j.MDCContext
4+
import kotlinx.coroutines.withContext
5+
import misk.logging.Mdc
6+
import mu.KotlinLogging
7+
import wisp.logging.getLogger
8+
import kotlin.coroutines.coroutineContext
9+
10+
/**
11+
* Adds the given key, value pair to the MDC for the duration of the block.
12+
* This is coroutine safe, so the additions will be added to the coroutine context
13+
*/
14+
suspend inline fun Mdc.withMdc(key: String, value: String, crossinline block: suspend () -> Unit) =
15+
withMdc(key to value, block = block)
16+
17+
/**
18+
* Adds the given tags to the MDC for the duration of the block.
19+
* This is coroutine safe, so the additions will be added to the coroutine context
20+
*/
21+
suspend inline fun Mdc.withMdc(
22+
vararg tags: Pair<String, String>,
23+
crossinline block: suspend () -> Unit
24+
) {
25+
if(coroutineContext[MDCContext] == null) {
26+
KotlinLogging.logger("misk.logging.coroutines.ScopedMdc").warn {
27+
"MDCContext is not present in the coroutine context, this is required to restore the previous MDC state"
28+
}
29+
}
30+
tags.forEach { (key, value) -> put(key, value) }
31+
return withContext(MDCContext()) {
32+
block()
33+
}
34+
}
35+

misk/src/main/kotlin/misk/web/actions/WebActions.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ package misk.web.actions
22

33
import kotlinx.coroutines.launch
44
import kotlinx.coroutines.runBlocking
5+
import kotlinx.coroutines.slf4j.MDCContext
56
import misk.ApplicationInterceptor
67
import misk.Chain
78
import misk.grpc.GrpcMessageSinkChannel
89
import misk.grpc.GrpcMessageSourceChannel
9-
import misk.logging.DynamicMdcContext
1010
import misk.scope.ActionScope
1111
import misk.web.HttpCall
1212
import misk.web.RealChain
@@ -40,7 +40,7 @@ internal fun WebAction.asChain(
4040
} else {
4141
// Handle suspending invocation, this includes building out the context to propagate MDC
4242
// and action scope.
43-
val context = DynamicMdcContext() +
43+
val context = MDCContext() +
4444
if (scope.inScope()) {
4545
scope.asContextElement()
4646
} else {

misk/src/test/kotlin/misk/logging/DynamicMdcContextTest.kt

-48
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package misk.logging
2+
3+
import jakarta.inject.Inject
4+
import kotlinx.coroutines.delay
5+
import kotlinx.coroutines.slf4j.MDCContext
6+
import kotlinx.coroutines.test.runTest
7+
import misk.MiskTestingServiceModule
8+
import misk.testing.MiskTest
9+
import misk.testing.MiskTestModule
10+
import org.junit.jupiter.api.Test
11+
import kotlin.test.assertEquals
12+
import kotlin.test.assertNull
13+
import misk.logging.coroutines.withMdc as withMdcCoroutines
14+
15+
@MiskTest(startService = false)
16+
internal class ScopedMdcKtTest {
17+
@MiskTestModule
18+
val module = MiskTestingServiceModule()
19+
20+
@Inject
21+
lateinit var mdc: Mdc
22+
23+
@Test
24+
fun `test withMdc in a coroutine for key value pairs`() = runTest(MDCContext()) {
25+
val tags = (1..3).map { "key$it" to "value$it" }.toTypedArray()
26+
mdc.withMdcCoroutines(*tags) {
27+
tags.assertTags()
28+
delay(100)
29+
tags.assertTags()
30+
}
31+
tags.asserMissingTags()
32+
}
33+
34+
@Test
35+
fun `test withMdc for key value pairs`() {
36+
val tags = (1..3).map { "key$it" to "value$it" }.toTypedArray()
37+
mdc.withMdc(*tags) {
38+
tags.assertTags()
39+
}
40+
tags.asserMissingTags()
41+
}
42+
43+
@Test
44+
fun `test withMdc in a coroutine for key value pair overrides`() = runTest(MDCContext()) {
45+
val tags = (1..3).map { "key$it" to "value$it" }.toTypedArray()
46+
mdc.withMdcCoroutines(*tags) {
47+
tags.assertTags()
48+
delay(100)
49+
tags.assertTags()
50+
val updatedTags = tags.map { if (it.first == "key1"){it.first to it.second+"00"} else {it} }.toTypedArray()
51+
mdc.withMdcCoroutines(*updatedTags) {
52+
updatedTags.assertTags()
53+
delay(100)
54+
updatedTags.assertTags()
55+
}
56+
tags.forEach { it.asserTag() }
57+
}
58+
tags.asserMissingTags()
59+
}
60+
61+
@Test
62+
fun `test withMdc for key value pair overrides`() {
63+
val tags = (1..3).map { "key$it" to "value$it" }.toTypedArray()
64+
mdc.withMdc(*tags) {
65+
tags.assertTags()
66+
val updatedTags = tags.map { if (it.first == "key1"){it.first to it.second+"00"} else {it} }.toTypedArray()
67+
mdc.withMdc(*updatedTags) {
68+
updatedTags.assertTags()
69+
}
70+
tags.forEach { it.asserTag() }
71+
}
72+
tags.asserMissingTags()
73+
}
74+
75+
fun Pair<String, String>.asserTag() = assertEquals(second, mdc.get(first))
76+
fun Array<Pair<String, String>>.assertTags() = forEach { it.asserTag() }
77+
fun Pair<String, String>.asserMissingTag() = assertNull( mdc.get(first))
78+
fun Array<Pair<String, String>>.asserMissingTags() = forEach { it.asserMissingTag() }
79+
}

0 commit comments

Comments
 (0)