Skip to content

Instantly share code, notes, and snippets.

@Jokeren
Created August 9, 2024 02:39
Show Gist options
  • Save Jokeren/483687e5bb4968f61a0564d35b06d724 to your computer and use it in GitHub Desktop.
mlirs
// First listing: TritonGPU-dialect IR for @hoist_convert_above_extf_and_remat.
// A 32x256 f32 matmul accumulator (#mma) is narrowed to f16, a 1x256 row is
// added, and a per-row normalization (mean / mean-of-squares) is computed
// entirely in the #mma layout. The only 32x256 layout conversion after the
// loop is the #mma -> #blocked1 convert of the f16 result (%59) right before
// the store.
// NOTE(review): the function name suggests this is the input to a pass that
// hoists convert_layout above arith.extf and rematerializes; confirm against
// the pass/test that consumes this IR.
//
// Layout aliases: three blocked register layouts, the MMA (versionMajor = 2)
// accumulator layout, and two shared-memory layouts differing only in `order`.
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
// Module attributes: 1 CTA, 4 warps of 32 threads, target "cuda:80".
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
// Seven f16 pointer arguments. Only %arg0 (A tiles), %arg1 (B tiles), %arg2
// (1x256 row) and %arg5 (output) are used below; %arg3, %arg4 and %arg6 are
// unused in this body.
tt.func public @hoist_convert_above_extf_and_remat(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16>) attributes {noinline = false} {
// Constants: zero accumulator, loop bounds/step (0, 256, 64), row stride 256,
// 1e-3 (added after the sqrt at %41) and 256.0 (divisor for the row means).
%cst = arith.constant dense<0.000000e+00> : tensor<32x256xf32, #mma>
%c32_i32 = arith.constant 32 : i32
%cst_0 = arith.constant dense<256> : tensor<32x1xi32, #blocked>
%cst_1 = arith.constant dense<256> : tensor<32x1xi32, #blocked1>
%cst_2 = arith.constant dense<256> : tensor<256x1xi32, #blocked>
%c64_i32 = arith.constant 64 : i32
%c256_i32 = arith.constant 256 : i32
%c0_i32 = arith.constant 0 : i32
%cst_3 = arith.constant dense<1.000000e-03> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%cst_4 = arith.constant dense<2.560000e+02> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// %1 = program_id(x) * 32: first global row of this program's 32-row tile.
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c32_i32 : i32
// A row offsets: %6 = (%1 + [0, 32)) * 256.
%2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
%4 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked>
%5 = arith.addi %4, %3 : tensor<32x1xi32, #blocked>
%6 = arith.muli %5, %cst_0 : tensor<32x1xi32, #blocked>
// Column indices [0, 64). %7/%8 are identical make_range ops and %9/%10
// identical expand_dims; they have not been deduplicated in this IR.
%7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%9 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%10 = tt.expand_dims %8 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%11 = tt.broadcast %9 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
// B row offsets: %14 = [0, 256) * 256; %15 = column indices broadcast to 256x64.
%12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked>
%14 = arith.muli %13, %cst_2 : tensor<256x1xi32, #blocked>
%15 = tt.broadcast %10 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked>
// Base pointers splatted to tile shape: A from %arg0, B from %arg1.
%16 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
%17 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked>
// K loop over k = 0, 64, 128, 192. Each iteration loads a 32x64 A tile at
// (%1 + r)*256 + k + c and a 256x64 B tile at r*256 + k + c, then accumulates
// A_tile * transpose(B_tile) into the 32x256 f32 #mma accumulator.
%18 = scf.for %arg7 = %c0_i32 to %c256_i32 step %c64_i32 iter_args(%arg8 = %cst) -> (tensor<32x256xf32, #mma>) : i32 {
%60 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #blocked>
%61 = arith.addi %6, %60 : tensor<32x1xi32, #blocked>
%62 = tt.broadcast %61 : tensor<32x1xi32, #blocked> -> tensor<32x64xi32, #blocked>
%63 = arith.addi %62, %11 : tensor<32x64xi32, #blocked>
%64 = tt.splat %arg7 : i32 -> tensor<256x1xi32, #blocked>
%65 = arith.addi %14, %64 : tensor<256x1xi32, #blocked>
%66 = tt.broadcast %65 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked>
%67 = arith.addi %66, %15 : tensor<256x64xi32, #blocked>
%68 = tt.addptr %16, %63 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
%69 = tt.load %68 : tensor<32x64x!tt.ptr<f16>, #blocked>
%70 = tt.addptr %17, %67 : tensor<256x64x!tt.ptr<f16>, #blocked>, tensor<256x64xi32, #blocked>
%71 = tt.load %70 : tensor<256x64x!tt.ptr<f16>, #blocked>
// B is staged through shared memory (local_alloc), viewed transposed as
// 64x256, and loaded as dot operand 1.
%72 = triton_gpu.local_alloc %71 : (tensor<256x64xf16, #blocked>) -> !tt.memdesc<256x64xf16, #shared>
%73 = tt.trans %72 {order = array<i32: 1, 0>} : !tt.memdesc<256x64xf16, #shared> -> !tt.memdesc<64x256xf16, #shared1>
%74 = triton_gpu.local_load %73 : !tt.memdesc<64x256xf16, #shared1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
// A is converted straight from #blocked registers to dot operand 0.
%75 = triton_gpu.convert_layout %69 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%76 = tt.dot %75, %74, %arg8 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma>
scf.yield %76 : tensor<32x256xf32, #mma>
}
// Epilogue: narrow the accumulator to f16 (still in #mma).
%19 = arith.truncf %18 : tensor<32x256xf32, #mma> to tensor<32x256xf16, #mma>
// Column indices [0, 256): %20/%22 in #blocked2 for the %arg2 load, %21/%23
// in #blocked1 for the output offsets (%55).
%20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%22 = tt.expand_dims %20 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2>
%23 = tt.expand_dims %21 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1>
// Load a 1x256 f16 row from %arg2, convert it #blocked2 -> #mma, and add it to
// every row of %19 (presumably a bias — confirm with the source kernel).
%24 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<1x256x!tt.ptr<f16>, #blocked2>
%25 = tt.addptr %24, %22 : tensor<1x256x!tt.ptr<f16>, #blocked2>, tensor<1x256xi32, #blocked2>
%26 = tt.load %25 : tensor<1x256x!tt.ptr<f16>, #blocked2>
%27 = triton_gpu.convert_layout %26 : tensor<1x256xf16, #blocked2> -> tensor<1x256xf16, #mma>
%28 = tt.broadcast %27 : tensor<1x256xf16, #mma> -> tensor<32x256xf16, #mma>
%29 = arith.addf %19, %28 : tensor<32x256xf16, #mma>
// Three identical extf ops of %29 to f32 in #mma (%30, %31, %32), one per
// consumer below; they have not been deduplicated in this IR.
%30 = arith.extf %29 : tensor<32x256xf16, #mma> to tensor<32x256xf32, #mma>
%31 = arith.extf %29 : tensor<32x256xf16, #mma> to tensor<32x256xf32, #mma>
%32 = arith.extf %29 : tensor<32x256xf16, #mma> to tensor<32x256xf32, #mma>
// %34 = per-row mean: sum over axis 1, divided by 256.
%33 = "tt.reduce"(%30) <{axis = 1 : i32}> ({
^bb0(%arg7: f32, %arg8: f32):
%60 = arith.addf %arg7, %arg8 : f32
tt.reduce.return %60 : f32
}) : (tensor<32x256xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%34 = arith.divf %33, %cst_4 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// %37 = per-row mean of squares.
%35 = arith.mulf %31, %31 : tensor<32x256xf32, #mma>
%36 = "tt.reduce"(%35) <{axis = 1 : i32}> ({
^bb0(%arg7: f32, %arg8: f32):
%60 = arith.addf %arg7, %arg8 : f32
tt.reduce.return %60 : f32
}) : (tensor<32x256xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%37 = arith.divf %36, %cst_4 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// %41 = sqrt(mean_sq - mean^2) + 1e-3. The 1e-3 is added after the sqrt, not
// to the variance.
%38 = arith.mulf %34, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%39 = arith.subf %37, %38 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%40 = math.sqrt %39 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%41 = arith.addf %40, %cst_3 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// %47 = (x - mean) / %41, with the per-row values broadcast back to 32x256;
// %48 narrows the result to f16 (still in #mma).
%42 = tt.expand_dims %34 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xf32, #mma>
%43 = tt.expand_dims %41 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xf32, #mma>
%44 = tt.broadcast %42 : tensor<32x1xf32, #mma> -> tensor<32x256xf32, #mma>
%45 = arith.subf %32, %44 : tensor<32x256xf32, #mma>
%46 = tt.broadcast %43 : tensor<32x1xf32, #mma> -> tensor<32x256xf32, #mma>
%47 = arith.divf %45, %46 : tensor<32x256xf32, #mma>
%48 = arith.truncf %47 : tensor<32x256xf32, #mma> to tensor<32x256xf16, #mma>
// Output offsets: %53 = %1 + r*256 for r in [0, 32), plus the column (%56).
// NOTE(review): the A load used (%1 + r)*256 (%6), but here %1 is added after
// the *256, so for program_id > 0 each row is written to row r shifted by %1
// columns instead of to row %1 + r. Identical when only program 0 runs;
// otherwise likely a bug in the kernel this IR came from — confirm.
%49 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%50 = tt.expand_dims %49 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
%51 = arith.muli %50, %cst_1 : tensor<32x1xi32, #blocked1>
%52 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked1>
%53 = arith.addi %52, %51 : tensor<32x1xi32, #blocked1>
%54 = tt.broadcast %53 : tensor<32x1xi32, #blocked1> -> tensor<32x256xi32, #blocked1>
%55 = tt.broadcast %23 : tensor<1x256xi32, #blocked1> -> tensor<32x256xi32, #blocked1>
%56 = arith.addi %54, %55 : tensor<32x256xi32, #blocked1>
%57 = tt.splat %arg5 : !tt.ptr<f16> -> tensor<32x256x!tt.ptr<f16>, #blocked1>
%58 = tt.addptr %57, %56 : tensor<32x256x!tt.ptr<f16>, #blocked1>, tensor<32x256xi32, #blocked1>
// Convert the f16 result #mma -> #blocked1 to match the pointer layout, then store.
%59 = triton_gpu.convert_layout %48 : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked1>
tt.store %58, %59 : tensor<32x256x!tt.ptr<f16>, #blocked1>
tt.return
}
}
// -----
// Second listing: same function name, loop and address computation as the
// first listing in this file. The difference is the post-loop epilogue.
// %29 (f16, #mma) is converted #mma -> #blocked1 *before* arith.extf, so the
// f32 work (extf, reductions, normalization) happens in #blocked1. The result
// %51 is stored without a trailing convert_layout.
// NOTE(review): presumably this is the output of the convert-hoisting pass
// applied to the first listing — confirm.
// NOTE(review): this listing redefines the same #blocked*/#mma/#shared* aliases
// as the first one. The MLIR parser is expected to reject that within a single
// parse. The `// -----` line above is the separator that
// `mlir-opt --split-input-file` uses to parse the listings independently.
//
// Layout aliases: three blocked register layouts, the MMA (versionMajor = 2)
// accumulator layout, and two shared-memory layouts differing only in `order`.
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
// Module attributes: 1 CTA, 4 warps of 32 threads, target "cuda:80".
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
// Seven f16 pointer arguments. Only %arg0 (A tiles), %arg1 (B tiles), %arg2
// (1x256 row) and %arg5 (output) are used below; %arg3, %arg4 and %arg6 are
// unused in this body.
tt.func public @hoist_convert_above_extf_and_remat(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16>) attributes {noinline = false} {
// Constants. %cst_3 (1e-3, added after the sqrt at %44) and %cst_4 (256.0, the
// divisor for the row means) are in the slice of #blocked1.
%cst = arith.constant dense<0.000000e+00> : tensor<32x256xf32, #mma>
%c32_i32 = arith.constant 32 : i32
%cst_0 = arith.constant dense<256> : tensor<32x1xi32, #blocked>
%cst_1 = arith.constant dense<256> : tensor<32x1xi32, #blocked1>
%cst_2 = arith.constant dense<256> : tensor<256x1xi32, #blocked>
%c64_i32 = arith.constant 64 : i32
%c256_i32 = arith.constant 256 : i32
%c0_i32 = arith.constant 0 : i32
%cst_3 = arith.constant dense<1.000000e-03> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%cst_4 = arith.constant dense<2.560000e+02> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// %1 = program_id(x) * 32: first global row of this program's 32-row tile.
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c32_i32 : i32
// A row offsets: %6 = (%1 + [0, 32)) * 256.
%2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
%4 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked>
%5 = arith.addi %4, %3 : tensor<32x1xi32, #blocked>
%6 = arith.muli %5, %cst_0 : tensor<32x1xi32, #blocked>
// Column indices [0, 64). %7/%8 are identical make_range ops and %9/%10
// identical expand_dims; they have not been deduplicated in this IR.
%7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%9 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%10 = tt.expand_dims %8 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%11 = tt.broadcast %9 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
// B row offsets: %14 = [0, 256) * 256; %15 = column indices broadcast to 256x64.
%12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked>
%14 = arith.muli %13, %cst_2 : tensor<256x1xi32, #blocked>
%15 = tt.broadcast %10 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked>
// Base pointers splatted to tile shape: A from %arg0, B from %arg1.
%16 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
%17 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked>
// K loop over k = 0, 64, 128, 192. Each iteration loads a 32x64 A tile at
// (%1 + r)*256 + k + c and a 256x64 B tile at r*256 + k + c, then accumulates
// A_tile * transpose(B_tile) into the 32x256 f32 #mma accumulator.
%18 = scf.for %arg7 = %c0_i32 to %c256_i32 step %c64_i32 iter_args(%arg8 = %cst) -> (tensor<32x256xf32, #mma>) : i32 {
%62 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #blocked>
%63 = arith.addi %6, %62 : tensor<32x1xi32, #blocked>
%64 = tt.broadcast %63 : tensor<32x1xi32, #blocked> -> tensor<32x64xi32, #blocked>
%65 = arith.addi %64, %11 : tensor<32x64xi32, #blocked>
%66 = tt.splat %arg7 : i32 -> tensor<256x1xi32, #blocked>
%67 = arith.addi %14, %66 : tensor<256x1xi32, #blocked>
%68 = tt.broadcast %67 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked>
%69 = arith.addi %68, %15 : tensor<256x64xi32, #blocked>
%70 = tt.addptr %16, %65 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
%71 = tt.load %70 : tensor<32x64x!tt.ptr<f16>, #blocked>
%72 = tt.addptr %17, %69 : tensor<256x64x!tt.ptr<f16>, #blocked>, tensor<256x64xi32, #blocked>
%73 = tt.load %72 : tensor<256x64x!tt.ptr<f16>, #blocked>
// B is staged through shared memory (local_alloc), viewed transposed as
// 64x256, and loaded as dot operand 1.
%74 = triton_gpu.local_alloc %73 : (tensor<256x64xf16, #blocked>) -> !tt.memdesc<256x64xf16, #shared>
%75 = tt.trans %74 {order = array<i32: 1, 0>} : !tt.memdesc<256x64xf16, #shared> -> !tt.memdesc<64x256xf16, #shared1>
%76 = triton_gpu.local_load %75 : !tt.memdesc<64x256xf16, #shared1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
// A is converted straight from #blocked registers to dot operand 0.
%77 = triton_gpu.convert_layout %71 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%78 = tt.dot %77, %76, %arg8 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma>
scf.yield %78 : tensor<32x256xf32, #mma>
}
// Epilogue: narrow the accumulator to f16 (still in #mma).
%19 = arith.truncf %18 : tensor<32x256xf32, #mma> to tensor<32x256xf16, #mma>
// Column indices [0, 256): %20/%22 in #blocked2 for the %arg2 load, %21/%23
// in #blocked1 for the output offsets (%58).
%20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%22 = tt.expand_dims %20 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2>
%23 = tt.expand_dims %21 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1>
// Load a 1x256 f16 row from %arg2, convert it #blocked2 -> #mma, and add it to
// every row of %19 (presumably a bias — confirm with the source kernel).
%24 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<1x256x!tt.ptr<f16>, #blocked2>
%25 = tt.addptr %24, %22 : tensor<1x256x!tt.ptr<f16>, #blocked2>, tensor<1x256xi32, #blocked2>
%26 = tt.load %25 : tensor<1x256x!tt.ptr<f16>, #blocked2>
%27 = triton_gpu.convert_layout %26 : tensor<1x256xf16, #blocked2> -> tensor<1x256xf16, #mma>
%28 = tt.broadcast %27 : tensor<1x256xf16, #mma> -> tensor<32x256xf16, #mma>
%29 = arith.addf %19, %28 : tensor<32x256xf16, #mma>
// Each f32 consumer gets its own f16 #mma -> #blocked1 convert followed by an
// extf: (%30 -> %31), (%32 -> %33), (%34 -> %35). The conversion therefore
// moves f16 rather than f32 data. It also yields three identical 32x256
// conversions of %29, where the first listing had a single one (%59 there).
// NOTE(review): these are redundant unless a later pass deduplicates them — confirm.
%30 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked1>
%31 = arith.extf %30 : tensor<32x256xf16, #blocked1> to tensor<32x256xf32, #blocked1>
%32 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked1>
%33 = arith.extf %32 : tensor<32x256xf16, #blocked1> to tensor<32x256xf32, #blocked1>
%34 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked1>
%35 = arith.extf %34 : tensor<32x256xf16, #blocked1> to tensor<32x256xf32, #blocked1>
// %37 = per-row mean: sum over axis 1, divided by 256.
%36 = "tt.reduce"(%31) <{axis = 1 : i32}> ({
^bb0(%arg7: f32, %arg8: f32):
%62 = arith.addf %arg7, %arg8 : f32
tt.reduce.return %62 : f32
}) : (tensor<32x256xf32, #blocked1>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%37 = arith.divf %36, %cst_4 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// %40 = per-row mean of squares.
%38 = arith.mulf %33, %33 : tensor<32x256xf32, #blocked1>
%39 = "tt.reduce"(%38) <{axis = 1 : i32}> ({
^bb0(%arg7: f32, %arg8: f32):
%62 = arith.addf %arg7, %arg8 : f32
tt.reduce.return %62 : f32
}) : (tensor<32x256xf32, #blocked1>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%40 = arith.divf %39, %cst_4 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// %44 = sqrt(mean_sq - mean^2) + 1e-3. The 1e-3 is added after the sqrt, not
// to the variance.
%41 = arith.mulf %37, %37 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%42 = arith.subf %40, %41 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%43 = math.sqrt %42 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%44 = arith.addf %43, %cst_3 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// %50 = (x - mean) / %44, with the per-row values broadcast back to 32x256;
// %51 narrows the result to f16, already in the store layout #blocked1.
%45 = tt.expand_dims %37 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xf32, #blocked1>
%46 = tt.expand_dims %44 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xf32, #blocked1>
%47 = tt.broadcast %45 : tensor<32x1xf32, #blocked1> -> tensor<32x256xf32, #blocked1>
%48 = arith.subf %35, %47 : tensor<32x256xf32, #blocked1>
%49 = tt.broadcast %46 : tensor<32x1xf32, #blocked1> -> tensor<32x256xf32, #blocked1>
%50 = arith.divf %48, %49 : tensor<32x256xf32, #blocked1>
%51 = arith.truncf %50 : tensor<32x256xf32, #blocked1> to tensor<32x256xf16, #blocked1>
// Output offsets: %56 = %1 + r*256 for r in [0, 32), plus the column (%59).
// NOTE(review): the A load used (%1 + r)*256 (%6), but here %1 is added after
// the *256, so for program_id > 0 each row is written to row r shifted by %1
// columns instead of to row %1 + r. Identical when only program 0 runs;
// otherwise likely a bug in the kernel this IR came from — confirm.
%52 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%53 = tt.expand_dims %52 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
%54 = arith.muli %53, %cst_1 : tensor<32x1xi32, #blocked1>
%55 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked1>
%56 = arith.addi %55, %54 : tensor<32x1xi32, #blocked1>
%57 = tt.broadcast %56 : tensor<32x1xi32, #blocked1> -> tensor<32x256xi32, #blocked1>
%58 = tt.broadcast %23 : tensor<1x256xi32, #blocked1> -> tensor<32x256xi32, #blocked1>
%59 = arith.addi %57, %58 : tensor<32x256xi32, #blocked1>
%60 = tt.splat %arg5 : !tt.ptr<f16> -> tensor<32x256x!tt.ptr<f16>, #blocked1>
%61 = tt.addptr %60, %59 : tensor<32x256x!tt.ptr<f16>, #blocked1>, tensor<32x256xi32, #blocked1>
// %51 is already #blocked1, so no convert_layout is needed before the store.
tt.store %61, %51 : tensor<32x256x!tt.ptr<f16>, #blocked1>
tt.return
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment