Skip to content

Instantly share code, notes, and snippets.

@matthias-springer
Last active October 11, 2023 13:53
Show Gist options
  • Save matthias-springer/b664feb23be0159f72726025923bb9ca to your computer and use it in GitHub Desktop.
Save matthias-springer/b664feb23be0159f72726025923bb9ca to your computer and use it in GitHub Desktop.
// RUN: mlir-opt %s -test-transform-dialect-interpreter="debug-payload-root-tag=payload" -test-transform-dialect-erase-schedule
// Tile-size maps for the hand-tiled loops below. affine.min over these
// results yields min(17 - iv, 5) and min(29 - iv, 7): the size of the current
// tile, which is smaller than the loop step for the last (partial) tile
// because 5 does not divide 17 and 7 does not divide 29.
#map = affine_map<(d0) -> (-d0 + 17, 5)>
#map1 = affine_map<(d0) -> (-d0 + 29, 7)>
// Payload module. Its `transform.target_tag="payload"` attribute matches the
// `debug-payload-root-tag=payload` option in the RUN line, so the transform
// interpreter applies the top-level transform.sequence to this module.
module attributes {transform.target_tag="payload"} {
// Tiled batched matrix multiplication. %A and %B are the inputs, %C is the
// output.
// Shapes: A is 1x17x19, B is 1x19x29 and C is 1x17x29 (a batch of one), so
// C = A * B. The rows of C (dim 1) are tiled by 5 and the columns (dim 2) by 7.
func.func @test_matmul(%A: memref<1x17x19xf32>, %B: memref<1x19x29xf32>,
%C: memref<1x17x29xf32>) {
%c7 = arith.constant 7 : index
%c29 = arith.constant 29 : index
%c5 = arith.constant 5 : index
%c17 = arith.constant 17 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
// Tensor views of the memref arguments; only %C is marked writable.
%A_tensor = bufferization.to_tensor %A restrict : memref<1x17x19xf32>
%B_tensor = bufferization.to_tensor %B restrict : memref<1x19x29xf32>
%C_tensor = bufferization.to_tensor %C restrict writable : memref<1x17x29xf32>
// Zero-initialize the result, because batch_matmul accumulates into its
// `outs` operand.
%3 = linalg.fill ins(%cst : f32) outs(%C_tensor : tensor<1x17x29xf32>) -> tensor<1x17x29xf32>
// Outer tile loop over the 17 rows of C (dim 1), step 5; %5 is the size of
// the current row tile.
%4 = scf.for %arg3 = %c0 to %c17 step %c5 iter_args(%arg4 = %3) -> (tensor<1x17x29xf32>) {
%5 = affine.min #map(%arg3)
// Inner tile loop over the 29 columns of C (dim 2), step 7; %7 is the size
// of the current column tile. The iter_arg %arg6 starts as %arg4 and carries
// the partially updated result across the inner iterations.
%6 = scf.for %arg5 = %c0 to %c29 step %c7 iter_args(%arg6 = %arg4) -> (tensor<1x17x29xf32>) {
%7 = affine.min #map1(%arg5)
// Let's assume that something went wrong during tiling. Due to a bug, we
// extract from %arg4 instead of %arg6.
// This is intentional for the example. %arg4 stays fixed for the whole inner
// loop: it is the inner loop's initial value. The correct code would read
// the current value %arg6, which already contains earlier inner-loop updates.
%extracted_slice = tensor.extract_slice %arg4[0, %arg3, %arg5] [1, %5, %7] [1, 1, 1] : tensor<1x17x29xf32> to tensor<1x?x?xf32>
// Matching row tile of A (all 19 columns) and column tile of B (all 19 rows).
%extracted_slice_0 = tensor.extract_slice %A_tensor[0, %arg3, 0] [1, %5, 19] [1, 1, 1] : tensor<1x17x19xf32> to tensor<1x?x19xf32>
%extracted_slice_1 = tensor.extract_slice %B_tensor[0, 0, %arg5] [1, 19, %7] [1, 1, 1] : tensor<1x19x29xf32> to tensor<1x19x?xf32>
%8 = linalg.batch_matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x?x19xf32>, tensor<1x19x?xf32>) outs(%extracted_slice : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
// Write the computed tile back into the inner loop's iter_arg and carry it on.
%inserted_slice = tensor.insert_slice %8 into %arg6[0, %arg3, %arg5] [1, %5, %7] [1, 1, 1] : tensor<1x?x?xf32> into tensor<1x17x29xf32>
scf.yield %inserted_slice : tensor<1x17x29xf32>
}
scf.yield %6 : tensor<1x17x29xf32>
}
// Store the final result tensor into the output memref %C.
bufferization.materialize_in_destination %4 in restrict writable %C : (tensor<1x17x29xf32>, memref<1x17x29xf32>) -> ()
return
}
}
// Transform script driving the payload module: (1) tile, (2) bufferize,
// (3) clean up, (4) insert deallocations. Any failure propagates.
transform.sequence failures(propagate) {
^bb0(%root: !transform.any_op):
  // Step 1: tile the matmul. This step is elided here because the payload is
  // already tiled by hand. The func.func handle below is what that step would
  // operate on; it is otherwise unused.
  %funcs = transform.structured.match ops{["func.func"]} in %root
      : (!transform.any_op) -> !transform.any_op

  // Step 2: run One-Shot Bufferize on the payload module. All later steps
  // use the handle it returns rather than %root.
  %bufferized = transform.bufferization.one_shot_bufferize %root
      {allow_return_allocs_from_loops = true}
      : (!transform.any_op) -> !transform.any_op

  // Step 3: canonicalize, followed by CSE (the `apply_cse` attribute).
  transform.apply_patterns to %bufferized {
    transform.apply_patterns.canonicalization
  } {apply_cse} : !transform.any_op

  // Step 4: insert buffer deallocations via the registered pass pipeline.
  // The original note says this is a no-op because no temporary buffers are
  // allocated.
  // NOTE(review): with the deliberate %arg4/%arg6 bug in the payload,
  // bufferization may have to insert a copy (a temporary allocation). In that
  // case this pass does real work; confirm against the bufferized IR.
  %deallocated = transform.apply_registered_pass
      "buffer-deallocation-pipeline" to %bufferized
      : (!transform.any_op) -> !transform.any_op
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment