Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save matthias-springer/5cc5b29c1bd727a272a78d71f1e6e19a to your computer and use it in GitHub Desktop.
Save matthias-springer/5cc5b29c1bd727a272a78d71f1e6e19a to your computer and use it in GitHub Desktop.
// RUN: mlir-opt %s -test-transform-dialect-interpreter="debug-payload-root-tag=payload"
// Identity indexing map over a single dimension. It is used for both inputs
// and the output of the linalg.generic below.
#map = affine_map<(d0) -> (d0)>
// Payload module. The "payload" tag is what the RUN line selects through
// debug-payload-root-tag, so the transform script operates on this module.
module attributes {transform.target_tag="payload"} {
// Element-wise f32 addition of two 100-element buffers, computed in tiles of
// 5 elements. %arg0 and %arg1 are the inputs, %arg2 receives the result.
// NOTE(review): despite the name `test_matmul`, no matmul is performed here;
// the body only contains an arith.addf inside a 1-D parallel linalg.generic.
func.func @test_matmul(%arg0: memref<100xf32>, %arg1: memref<100xf32>, %arg2: memref<100xf32>) {
// View the two input buffers as tensors.
%0 = bufferization.to_tensor %arg0 restrict : memref<100xf32>
%1 = bufferization.to_tensor %arg1 restrict : memref<100xf32>
// Uninitialized destination tensor; it is the initial value of the loop's
// iter_arg below.
%2 = tensor.empty() : tensor<100xf32>
// Loop bounds: offsets 0, 5, ..., 95 (20 iterations of tile size 5).
%c0 = arith.constant 0 : index
%c100 = arith.constant 100 : index
%c5 = arith.constant 5 : index
// %arg3 is the tile offset, %arg4 carries the partially filled result.
%3 = scf.for %arg3 = %c0 to %c100 step %c5 iter_args(%arg4 = %2) -> (tensor<100xf32>) {
// 5-element tiles of both inputs at the current offset.
%extracted_slice = tensor.extract_slice %0[%arg3] [5] [1] : tensor<100xf32> to tensor<5xf32>
%extracted_slice_0 = tensor.extract_slice %1[%arg3] [5] [1] : tensor<100xf32> to tensor<5xf32>
// DELIBERATE BUG (kept on purpose, do not "fix"): this models a tiling
// mistake where the output tile is extracted from %2 (the tensor.empty)
// instead of from the loop-carried %arg4. The payload of the
// linalg.generic below never reads %out, so the extracted tile only serves
// as the destination operand.
%extracted_slice_1 = tensor.extract_slice %2[%arg3] [5] [1] : tensor<100xf32> to tensor<5xf32>
// Tile computation: out[i] = in[i] + in_2[i].
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<5xf32>, tensor<5xf32>) outs(%extracted_slice_1 : tensor<5xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%5 = arith.addf %in, %in_2 : f32
linalg.yield %5 : f32
} -> tensor<5xf32>
// Write the computed tile back into the loop-carried tensor (%arg4, not %2)
// at the same offset and yield it to the next iteration.
%inserted_slice = tensor.insert_slice %4 into %arg4[%arg3] [5] [1] : tensor<5xf32> into tensor<100xf32>
scf.yield %inserted_slice : tensor<100xf32>
}
// Request that the final tensor be materialized in the output buffer %arg2.
bufferization.materialize_in_destination %3 in restrict writable %arg2 : (tensor<100xf32>, memref<100xf32>) -> ()
return
}
}
// Transform script applied to the payload module: clean up the function,
// try to eliminate tensor.empty ops, bufferize the module, clean up again and
// finally run the buffer deallocation pipeline. Any failing step aborts the
// sequence (failures(propagate)).
transform.sequence failures(propagate) {
^bb0(%payload_root: !transform.any_op):
  // Step 1: collect every func.func nested under the payload root.
  %funcs = transform.structured.match ops{["func.func"]} in %payload_root : (!transform.any_op) -> !transform.any_op

  // Step 2: canonicalization patterns on the matched functions, with CSE
  // enabled through the apply_cse attribute.
  transform.apply_patterns to %funcs {
    transform.apply_patterns.canonicalization
  } {apply_cse} : !transform.any_op

  // Step 3: empty tensor elimination on the matched functions.
  transform.bufferization.eliminate_empty_tensors %funcs : !transform.any_op

  // Step 4: One-Shot Bufferize on the payload root; the op returns a new
  // handle that is used for all following steps.
  %bufferized = transform.bufferization.one_shot_bufferize %payload_root : (!transform.any_op) -> !transform.any_op

  // Step 5: canonicalization + CSE again, this time on the bufferized module.
  transform.apply_patterns to %bufferized {
    transform.apply_patterns.canonicalization
  } {apply_cse} : !transform.any_op

  // Step 6: buffer deallocation. Per the original author's note this is
  // expected to be a no-op for this example, because no temporary buffers are
  // allocated.
  %deallocated = transform.apply_registered_pass "buffer-deallocation-pipeline" to %bufferized : (!transform.any_op) -> !transform.any_op
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment