Stateful metrics produce incorrect output in distributed context #21054

Open
drasmuss opened this issue Mar 17, 2025 · 0 comments

@drasmuss
Contributor

Metric state variables do not appear to be aggregated correctly across replicas when running under `tf.distribute.MirroredStrategy` in Keras 3. Here is a toy example that simply counts the number of inputs:

```python
import tensorflow as tf
import keras

# import tf_keras as keras  # swap in to compare against tf-keras

n_replicas = 4

# Split the first physical GPU into n_replicas logical GPUs so that
# MirroredStrategy runs with n_replicas replicas on a single card.
gpus = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(
    gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1000)] * n_replicas
)


class CountInputs(keras.metrics.Metric):
    """Counts the total number of samples seen."""

    def __init__(self):
        super().__init__()
        self.var = self.add_weight(name="var", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Add the number of samples in the (per-replica) batch.
        val = tf.shape(y_pred)[0]
        self.var.assign_add(val)

    def reset_state(self):
        self.var.assign(0)

    def result(self):
        return tf.cast(self.var, "float32")


# 10 batches of 12 samples, so the expected count is 120.
batch_size = 12
x = tf.zeros((batch_size * 10, 1))
y = tf.zeros((batch_size * 10, 1))

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    inp = keras.Input(shape=(1,))
    layer = keras.layers.Dense(10)
    model = keras.Model(inp, layer(inp))
    model.compile(loss="mse", optimizer="sgd", metrics=[CountInputs()])
    model.evaluate(x, y, batch_size=batch_size)
```

In tf-keras this produces the expected output (10 batches of 12 samples, i.e. 120 inputs):

```
10/10 [==============================] - 1s 5ms/step - loss: 0.0000e+00 - count_inputs: 120.0000
```
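To make the expected value concrete, here is a framework-free sketch of the count you would get if each replica keeps its own copy of `var` and the reported result is the sum of those copies. This is believed to match tf-keras's defaults for metric variables (on-read synchronization with SUM aggregation), but treat that as an assumption. `ReplicaLocalSum` and `shard` are hypothetical helpers, not TensorFlow or Keras APIs, and the even split of each global batch across replicas is also assumed.

```python
# Framework-free sketch. Assumption: metric variables are replica-local and
# reading them sums across replicas (believed to be tf-keras's default).
# ReplicaLocalSum and shard are hypothetical names, not TF/Keras APIs.


class ReplicaLocalSum:
    """One counter per replica; reading the value sums across replicas."""

    def __init__(self, n_replicas):
        self.per_replica = [0] * n_replicas

    def assign_add(self, replica_id, value):
        # Each replica only ever updates its own local copy.
        self.per_replica[replica_id] += value

    def read(self):
        # Cross-replica read: the sum of all local copies.
        return sum(self.per_replica)


def shard(batch_size, n_replicas):
    """Split a global batch as evenly as possible (assumed split policy)."""
    base, extra = divmod(batch_size, n_replicas)
    return [base + (1 if i < extra else 0) for i in range(n_replicas)]


n_replicas = 4
batch_size = 12
n_steps = 10  # 120 samples / batch_size of 12

count = ReplicaLocalSum(n_replicas)
for _ in range(n_steps):
    for replica_id, local_batch in enumerate(shard(batch_size, n_replicas)):
        # Mirrors update_state: add the size of this replica's shard.
        count.assign_add(replica_id, local_batch)

print(count.per_replica)  # [30, 30, 30, 30]
print(count.read())       # 120
```

Because each step's shard sizes always add up to the global batch size, the summed count is 120 for any `n_replicas`, which is the value tf-keras reports.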

But in Keras 3 this produces:

```
10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - count_inputs: 889696.3750 - loss: 0.0000e+00
```