@ekreutz
Last active August 4, 2024 15:01
Dot-product and Multi-head attention implementation in TensorFlow 2

Dot-product and Multi-head attention

Dot-product and multi-head attention from the paper "Attention Is All You Need" (2017). Implemented in modern TensorFlow 2 using the Keras API.
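For reference, these are the two operations from the paper that the layers below implement:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\, KW_i^K,\, VW_i^V)$$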

Example use of the implementations below:

# Imports needed for all snippets in this gist
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense

batch_size = 10
n_vectors = 150
d_model = 512

query = tf.random.uniform((batch_size, n_vectors, d_model), dtype=tf.float32)
key = tf.random.uniform((batch_size, n_vectors, d_model), dtype=tf.float32)
value = tf.random.uniform((batch_size, n_vectors, d_model), dtype=tf.float32)

# Test dot product attention
dp_layer = DotProductAttention(use_scale=True)
x = dp_layer([query, key, value])
print(f"Output from dot product attention: {x.shape}")

# Test multi-head attention
mh_layer = MultiHeadAttention(h=8)
x = mh_layer([query, key, value])
print(f"Output from multi-head attention: {x.shape}")

Dot-product attention:

class DotProductAttention(keras.layers.Layer):
    def __init__(self, use_scale=True, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.use_scale = use_scale

    def build(self, input_shape):
        query_shape = input_shape[0]
        if self.use_scale:
            # Scale scores by 1 / sqrt(d_k), as in the paper
            dim_k = tf.cast(query_shape[-1], tf.float32)
            self.scale = 1 / tf.sqrt(dim_k)
        else:
            self.scale = None

    def call(self, inputs):
        query, key, value = inputs
        # score has shape (batch, n_query, n_key)
        score = tf.matmul(query, key, transpose_b=True)
        if self.scale is not None:
            score *= self.scale
        # Softmax over the key axis, then take the weighted sum of the values
        return tf.matmul(tf.nn.softmax(score), value)
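A quick sanity check of the layer above on a tiny hand-made batch (my own example, not part of the original gist):

# Tiny example: batch of 1, two vectors of dimension 2
q = tf.constant([[[1.0, 0.0], [0.0, 1.0]]])  # shape (1, 2, 2)
k = tf.constant([[[1.0, 0.0], [0.0, 1.0]]])
v = tf.constant([[[1.0, 2.0], [3.0, 4.0]]])

layer = DotProductAttention(use_scale=True)
out = layer([q, k, v])
print(out.shape)  # (1, 2, 2): each output row is a softmax-weighted average of the rows of v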

Multi-head attention:

class MultiHeadAttention(keras.layers.Layer):
    def __init__(self, h=8, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.h = h

    def build(self, input_shape):
        query_shape, key_shape, value_shape = input_shape
        d_model = query_shape[-1]

        # d_k = d_v = d_model / h, as in the paper (any head size would work)
        units = d_model // self.h

        # One linear projection (no bias) per head for queries, keys and values
        self.layersQ = []
        for _ in range(self.h):
            layer = Dense(units, activation=None, use_bias=False)
            layer.build(query_shape)
            self.layersQ.append(layer)

        self.layersK = []
        for _ in range(self.h):
            layer = Dense(units, activation=None, use_bias=False)
            layer.build(key_shape)
            self.layersK.append(layer)

        self.layersV = []
        for _ in range(self.h):
            layer = Dense(units, activation=None, use_bias=False)
            layer.build(value_shape)
            self.layersV.append(layer)

        # Scaled dot-product attention, shared across heads (it has no weights)
        self.attention = DotProductAttention(use_scale=True)

        # Final linear projection back to d_model (W^O in the paper)
        self.out = Dense(d_model, activation=None, use_bias=False)
        self.out.build((query_shape[0], query_shape[1], self.h * units))

    def call(self, inputs):
        query, key, value = inputs

        # Project the inputs once per head
        q = [layer(query) for layer in self.layersQ]
        k = [layer(key) for layer in self.layersK]
        v = [layer(value) for layer in self.layersV]

        # Run scaled dot-product attention independently for each head
        head = [self.attention([q[i], k[i], v[i]]) for i in range(self.h)]

        out = self.out(tf.concat(head, -1))
        return out
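All of the Dense projections above are trainable weights, so the layer is normally trained as part of a larger model. Below is a minimal sketch of a single gradient step on random data (my own illustration, not from the gist; the MSE loss and Adam optimizer are arbitrary placeholders):

mha = MultiHeadAttention(h=8)
optimizer = keras.optimizers.Adam()

x = tf.random.uniform((batch_size, n_vectors, d_model))
target = tf.random.uniform((batch_size, n_vectors, d_model))

with tf.GradientTape() as tape:
    pred = mha([x, x, x])  # self-attention: query = key = value
    loss = tf.reduce_mean(tf.square(pred - target))  # placeholder MSE loss

grads = tape.gradient(loss, mha.trainable_weights)
optimizer.apply_gradients(zip(grads, mha.trainable_weights))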
sadrahkm commented Aug 4, 2024

Hi Emil,
Thank you for providing this piece of code.
Just one quick question: when using multi-head attention like this, does it need to be trained on data, or can it be used as-is without any training?
