Lei Zhang (@antiagainst)
antiagainst / conversion.mlir
Last active August 2, 2024 01:46
matvec in Triton
// -----// IR Dump Before ConvertTritonToTritonGPU (convert-triton-to-tritongpu) ('builtin.module' operation) //----- //
#loc = loc(unknown)
module {
tt.func public @matvec(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} {
%c2_i32 = arith.constant 2 : i32 loc(#loc)
%cst = arith.constant dense<0.000000e+00> : tensor<4x16xf32> loc(#loc)
%c1_i32 = arith.constant 1 : i32 loc(#loc)
%c0_i32 = arith.constant 0 : i32 loc(#loc)
%cst_0 = arith.constant dense<1024> : tensor<4x1xi32> loc(#loc)
%cst_1 = arith.constant dense<2048> : tensor<16x2048xi32> loc(#loc)
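For orientation, here is a minimal sketch, in Triton's Python DSL, of the kind of source that lowers to a tt.func like the one dumped above. The signature, block sizes, and masking are assumptions, not the gist's original kernel.

import triton
import triton.language as tl

@triton.jit
def matvec(a_ptr, x_ptr, out_ptr, M, N,
           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Each program instance reduces BLOCK_M rows of A against the vector x.
    pid = tl.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for n in range(0, N, BLOCK_N):
        cols = n + tl.arange(0, BLOCK_N)
        a = tl.load(a_ptr + rows[:, None] * N + cols[None, :],
                    mask=(rows[:, None] < M) & (cols[None, :] < N), other=0.0)
        x = tl.load(x_ptr + cols, mask=cols < N, other=0.0)
        # Accumulate partial dot products in f32, as the dense<0.000000e+00>
        # tensor<4x16xf32> accumulator in the dump suggests.
        acc += tl.sum(a.to(tl.float32) * x.to(tl.float32)[None, :], axis=1)
    tl.store(out_ptr + rows, acc.to(tl.float16), mask=rows < M)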
# Build: docker build -f triton-torch-hip.Dockerfile -t triton-torch-hip .
# Run: docker run -it --rm --device /dev/kfd --device /dev/dri triton-torch-hip
FROM ubuntu:22.04
ARG INSTALL_TORCH=TRUE
ARG TORCH_VERSION=2.4.0.dev20240530
ARG ROCM_VERSION=6.1.2
# Setup ROCm package signing key
# Build with `docker build . -t sdxl-repro --build-arg DOCKER_USERID=$(id -u) --build-arg DOCKER_GROUPID=$(id -g)`
# Run with `docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add $(getent group render | cut -d: -f3)
# --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/to/downloaded/sdxl/weights:/weights sdxl-repro`
# To benchmark inside docker: `./benchmark-txt2img.sh N /weights`
FROM rocm/dev-ubuntu-22.04
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
# Disable apt-key parse warning
# Use manylinux that ships with many Python versions for wheels. manylinux_2_28
# is AlmaLinux 8 based and is binary-compatible with Red Hat Enterprise Linux.
FROM quay.io/pypa/manylinux_2_28_x86_64@sha256:9042a22d33af2223ff7a3599f236aff1e4ffd07e1ed1ac93a58877638317515f
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
######## Setup Python ########
# Choose our default Python version
ENV PATH="/opt/python/cp311-cp311/bin:${PATH}"
hal.executable public @main$async_dispatch_205 {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>, #iree_gpu.mfma_layout<F16_32x32x8_F32>], target_arch = "gfx942", ukernels = "none"}>) {
hal.executable.export public @main$async_dispatch_205_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_f16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 3, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>], subgroup_size = 64 : index, translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4, subgroup_m_tile_count = 2, subgroup_n_tile_count = 4, subgroup_k_tile_count = 2>}>, workgroup_size = [256 :
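The mma_schedule attribute above pins down the per-workgroup tile: each extent is the product of the MFMA intrinsic shape, the subgroup count, and the per-subgroup tile count. A quick sketch of that arithmetic for this schedule (the composition rule is an assumption stated in the comments):

# Workgroup tile implied by the mma_schedule above, assuming each extent is
# intrinsic_size * subgroup_count * per-subgroup tile_count.
intrinsic_m, intrinsic_n, intrinsic_k = 16, 16, 16  # MFMA F16_16x16x16_F32
tile_m = intrinsic_m * 1 * 2  # subgroup_m_count=1, subgroup_m_tile_count=2 -> 32
tile_n = intrinsic_n * 4 * 4  # subgroup_n_count=4, subgroup_n_tile_count=4 -> 256
tile_k = intrinsic_k * 2      # subgroup_k_tile_count=2                     -> 32
subgroups = 1 * 4             # subgroup_m_count x subgroup_n_count
print(tile_m, tile_n, tile_k, subgroups * 64)  # 32 256 32 256
# 4 subgroups of subgroup_size 64 give the 256 threads visible in the
# truncated workgroup_size = [256 : ...] above.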
// -----// IR Dump After CSE (cse) //----- //
module {
func.func @conv_dispatch_1_conv_2d_nchw_fchw_2x8x16x16x8x3x3_f16() {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x8x33x33xf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<8x8x3x3xf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<8xf16>>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x8x16x16xf16>>
%workgroup_id_z = hal.interface.workgroup.id[2] : index
FROM ubuntu:22.04
SHELL ["/bin/bash", "-e", "-u", "-o", "pipefail", "-c"]
# Disable apt-key parse warning
ARG APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=1
# Basic development environment
RUN apt-get update && apt-get install -y \
curl wget \
//
// Generated by LLVM NVPTX Back-End
//
.version 7.6
.target sm_80
.address_size 64
// .globl matmul_3456x1024x2048_f32t_f32t_f32t_tile_config_default_dispatch_0_matmul_3456x1024x2048_f32
.extern .shared .align 16 .b8 __dynamic_shared_memory__[];
// -----// IR Dump After TypePropagation (iree-codegen-type-propagation) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%5:1283 = stream.resource.pack slices({
[0, 3] = %c640,
[0, 3] = %c153664,
[1, 3] = %c640,
[2, 6] = %c1280,
[3, 5] = %c640,
[3, 5] = %c1327104,
[4, 7] = %c1280,
[5, 989] = %c11796480,
[6, 8] = %c20480,
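Each [start, end] = size entry above names a slice with a liveness interval over ordered execution steps and a byte size; stream.resource.pack lays all slices out in one allocation so that slices with disjoint lifetimes can share offsets. A simplified sketch of that kind of packing as a greedy first-fit (IREE's actual algorithm may differ):

# Simplified sketch of lifetime-based slice packing in the spirit of
# stream.resource.pack: slices whose [start, end] lifetimes do not overlap
# may reuse the same byte range. Greedy first-fit, largest slices first.
def pack(slices):
    """slices: list of (start, end, size). Returns (total_size, offsets)."""
    placed = []                 # (offset, size, start, end) already packed
    result = [0] * len(slices)
    total = 0
    # Placing bigger slices first tends to reduce fragmentation.
    order = sorted(range(len(slices)), key=lambda i: -slices[i][2])
    for i in order:
        start, end, size = slices[i]
        # Only slices whose lifetimes overlap this one constrain placement.
        live = sorted((o, s) for (o, s, st, en) in placed
                      if not (end < st or en < start))
        offset = 0
        for o, s in live:
            if offset + size <= o:
                break           # fits in the gap before this live slice
            offset = max(offset, o + s)
        placed.append((offset, size, start, end))
        result[i] = offset
        total = max(total, offset + size)
    return total, result

# Usage with the first few slices from the dump above:
total, offs = pack([(0, 3, 640), (1, 3, 640), (2, 6, 1280), (3, 5, 640)])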