Skip to content

Instantly share code, notes, and snippets.

@ezyang
Created July 5, 2024 20:28
Show Gist options
  • Save ezyang/f59f52a1e5cfafe842bd5589656af7da to your computer and use it in GitHub Desktop.
Save ezyang/f59f52a1e5cfafe842bd5589656af7da to your computer and use it in GitHub Desktop.
# AOT ID: ['8_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
cpp_fused__softmax_add_mul_rsub_0 = async_compile.cpp_pybinding(['const bfloat16*', 'const float*', 'float*', 'float*', 'float*', 'bfloat16*'], '''
#include "/tmp/torchinductor_ezyang/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C" void kernel(const bfloat16* in_ptr0,
const float* in_ptr1,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
bfloat16* out_ptr3)
{
#pragma omp parallel num_threads(24)
{
int tid = omp_get_thread_num();
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(768L); x0+=static_cast<long>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(long x1=static_cast<long>(0L); x1<static_cast<long>(832L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(x1 + (832L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp7 = static_cast<float>(1.0);
auto tmp8 = at::vec::Vectorized<float>(tmp7);
auto tmp9 = tmp8 - tmp6;
auto tmp10 = static_cast<float>(-10000.0);
auto tmp11 = at::vec::Vectorized<float>(tmp10);
auto tmp12 = tmp9 * tmp11;
auto tmp13 = tmp5 + tmp12;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp13);
}
tmp_acc0 = max_propagate_nan(tmp_acc0, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return at::vec::maximum(x, y); }, tmp_acc0_vec));
out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0);
}
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(long x1=static_cast<long>(0L); x1<static_cast<long>(832L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(x1 + (832L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1), 16);
auto tmp14 = out_ptr0[static_cast<long>(x0)];
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp7 = static_cast<float>(1.0);
auto tmp8 = at::vec::Vectorized<float>(tmp7);
auto tmp9 = tmp8 - tmp6;
auto tmp10 = static_cast<float>(-10000.0);
auto tmp11 = at::vec::Vectorized<float>(tmp10);
auto tmp12 = tmp9 * tmp11;
auto tmp13 = tmp5 + tmp12;
auto tmp15 = at::vec::Vectorized<float>(tmp14);
auto tmp16 = tmp13 - tmp15;
auto tmp17 = tmp16.exp();
tmp17.store(out_ptr1 + static_cast<long>(x1 + (832L*x0)));
tmp_acc0_vec = tmp_acc0_vec + tmp17;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr2[static_cast<long>(x0)] = static_cast<float>(tmp_acc0);
}
for(long x1=static_cast<long>(0L); x1<static_cast<long>(832L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<long>(x1 + (832L*x0)), 16);
auto tmp1 = out_ptr2[static_cast<long>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
auto tmp4 = at::vec::convert<bfloat16>(tmp3);
tmp4.store(out_ptr3 + static_cast<long>(x1 + (832L*x0)), 16);
}
}
}
}
}
''')
cpp_fused__to_copy_cat_stack_1 = async_compile.cpp_pybinding(['const int32_t*', 'const int32_t*', 'const int32_t*', 'const int32_t*', 'const int32_t*', 'const int32_t*', 'const int32_t*', 'const int32_t*', 'const int32_t*', 'const int32_t*', 'const int32_t*', 'const int32_t*', 'const int32_t*', 'const bfloat16*', 'int32_t*', 'int32_t*', 'int32_t*', 'int32_t*', 'int32_t*', 'int32_t*', 'int32_t*', 'int32_t*', 'int32_t*', 'int32_t*', 'int32_t*', 'int32_t*', 'int64_t*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_ezyang/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C" void kernel(const int32_t* in_ptr0,
const int32_t* in_ptr1,
const int32_t* in_ptr2,
const int32_t* in_ptr3,
const int32_t* in_ptr4,
const int32_t* in_ptr5,
const int32_t* in_ptr6,
const int32_t* in_ptr7,
const int32_t* in_ptr8,
const int32_t* in_ptr9,
const int32_t* in_ptr10,
const int32_t* in_ptr11,
const int32_t* in_ptr12,
const bfloat16* in_ptr13,
int32_t* out_ptr0,
int32_t* out_ptr1,
int32_t* out_ptr2,
int32_t* out_ptr3,
int32_t* out_ptr4,
int32_t* out_ptr5,
int32_t* out_ptr6,
int32_t* out_ptr7,
int32_t* out_ptr8,
int32_t* out_ptr9,
int32_t* out_ptr10,
int32_t* out_ptr11,
int64_t* out_ptr12,
bfloat16* out_ptr13,
bfloat16* out_ptr14,
bfloat16* out_ptr15,
bfloat16* out_ptr16,
bfloat16* out_ptr17,
bfloat16* out_ptr18,
bfloat16* out_ptr19,
bfloat16* out_ptr20)
{
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr0 + static_cast<long>(x0), 16);
tmp0.store(out_ptr0 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
out_ptr0[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr1 + static_cast<long>(x0), 16);
tmp0.store(out_ptr1 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr1[static_cast<long>(x0)];
out_ptr1[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr2 + static_cast<long>(x0), 16);
tmp0.store(out_ptr2 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr2[static_cast<long>(x0)];
out_ptr2[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr3 + static_cast<long>(x0), 16);
tmp0.store(out_ptr3 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr3[static_cast<long>(x0)];
out_ptr3[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr4 + static_cast<long>(x0), 16);
tmp0.store(out_ptr4 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr4[static_cast<long>(x0)];
out_ptr4[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr5 + static_cast<long>(x0), 16);
tmp0.store(out_ptr5 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr5[static_cast<long>(x0)];
out_ptr5[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr6 + static_cast<long>(x0), 16);
tmp0.store(out_ptr6 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr6[static_cast<long>(x0)];
out_ptr6[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr7 + static_cast<long>(x0), 16);
tmp0.store(out_ptr7 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr7[static_cast<long>(x0)];
out_ptr7[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr8 + static_cast<long>(x0), 16);
tmp0.store(out_ptr8 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr8[static_cast<long>(x0)];
out_ptr8[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr9 + static_cast<long>(x0), 16);
tmp0.store(out_ptr9 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr9[static_cast<long>(x0)];
out_ptr9[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr10 + static_cast<long>(x0), 16);
tmp0.store(out_ptr10 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr10[static_cast<long>(x0)];
out_ptr10[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr11 + static_cast<long>(x0), 16);
tmp0.store(out_ptr11 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(32L); x0<static_cast<long>(33L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr11[static_cast<long>(x0)];
out_ptr11[static_cast<long>(x0)] = tmp0;
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(384L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr12 + static_cast<long>(x0), 16);
auto tmp1 = at::vec::convert<int64_t,2,int32_t,1>(tmp0);
tmp1.store(out_ptr12 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(384L); x0<static_cast<long>(396L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr12[static_cast<long>(x0)];
auto tmp1 = c10::convert<int64_t>(tmp0);
out_ptr12[static_cast<long>(x0)] = tmp1;
}
}
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr13 + static_cast<long>(x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr13 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
tmp0.store(out_ptr14 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr13 + static_cast<long>(49152L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr15 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr13 + static_cast<long>(98304L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr16 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr13 + static_cast<long>(589824L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr17 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
tmp0.store(out_ptr18 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
#pragma omp parallel num_threads(24)
{
int tid = omp_get_thread_num();
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(192L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(1L))
{
auto tmp0 = out_ptr12[static_cast<long>((33L*x0) + (33L*(c10::div_floor_integer((x2 + (64L*x1)), 135168L))) + (c10::div_floor_integer((x2 + (64L*x1)), 4096L)))];
auto tmp13 = out_ptr12[static_cast<long>(30L + (33L*x0) + (33L*(c10::div_floor_integer((122880L + x2 + (64L*x1)), 135168L))) + (c10::div_floor_integer((x2 + (64L*x1)), 4096L)))];
auto tmp1 = (13L*x0) + (26L*(c10::div_floor_integer((x2 + (64L*x1)), 135168L)));
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = 156L;
auto tmp5 = c10::convert<int64_t>(tmp4);
auto tmp6 = decltype(tmp3)(tmp3 + tmp5);
auto tmp7 = tmp3 < 0;
auto tmp8 = tmp7 ? tmp6 : tmp3;
auto tmp9 = tmp8;
auto tmp10 = c10::convert<int64_t>(tmp9);
TORCH_CHECK((0 <= tmp10) & (tmp10 < 156L), "index out of bounds: 0 <= tmp10 < 156L");
auto tmp12 = in_ptr13[static_cast<long>(x2 + (64L*(static_cast<long>(c10::div_floor_integer(tmp8, 13L)) % static_cast<long>(12L))) + (768L*(static_cast<long>(x1) % static_cast<long>(64L))) + (49152L*(static_cast<long>(tmp8) % static_cast<long>(13L))))];
auto tmp14 = (13L*x0) + (13L*(c10::div_floor_integer((30L + (c10::div_floor_integer((x2 + (64L*x1)), 4096L))), 33L))) + (13L*(c10::div_floor_integer((122880L + x2 + (64L*x1)), 135168L)));
auto tmp15 = c10::convert<int64_t>(tmp14);
auto tmp16 = decltype(tmp13)(tmp13 + tmp15);
auto tmp17 = decltype(tmp16)(tmp16 + tmp5);
auto tmp18 = tmp16 < 0;
auto tmp19 = tmp18 ? tmp17 : tmp16;
auto tmp20 = tmp19;
auto tmp21 = c10::convert<int64_t>(tmp20);
TORCH_CHECK((0 <= tmp21) & (tmp21 < 156L), "index out of bounds: 0 <= tmp21 < 156L");
auto tmp23 = in_ptr13[static_cast<long>(x2 + (64L*(static_cast<long>(c10::div_floor_integer(tmp19, 13L)) % static_cast<long>(12L))) + (768L*(static_cast<long>(x1) % static_cast<long>(64L))) + (49152L*(static_cast<long>(tmp19) % static_cast<long>(13L))))];
out_ptr19[static_cast<long>(x2 + (64L*x1) + (28672L*x0))] = tmp12;
out_ptr20[static_cast<long>(x2 + (64L*x1) + (28672L*x0))] = tmp23;
}
}
}
}
}
}
''')
cpp_fused__softmax_add_cat_minimum_mul_new_ones_rsub_2 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const int64_t*', 'const bfloat16*', 'const float*', 'const float*', 'const bfloat16*', 'float*', 'float*', 'float*', 'float*', 'float*', 'float*', 'float*', 'float*', 'float*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_ezyang/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const int64_t* in_ptr2,
const bfloat16* in_ptr3,
const float* in_ptr4,
const float* in_ptr5,
const bfloat16* in_ptr6,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
float* out_ptr3,
float* out_ptr4,
float* out_ptr5,
float* out_ptr6,
float* out_ptr7,
float* out_ptr8,
bfloat16* out_ptr9,
bfloat16* out_ptr10,
bfloat16* out_ptr11,
bfloat16* out_ptr12,
bfloat16* out_ptr13,
bfloat16* out_ptr14,
bfloat16* out_ptr15,
bfloat16* out_ptr16,
bfloat16* out_ptr17)
{
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(192L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x0), 16);
tmp0.store(out_ptr0 + static_cast<long>(x0));
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(768L + x0), 16);
tmp0.store(out_ptr1 + static_cast<long>(x0));
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(192L); x0+=static_cast<long>(16L))
{
auto tmp0 = static_cast<float>(1.0);
auto tmp1 = at::vec::Vectorized<float>(tmp0);
tmp1.store(out_ptr2 + static_cast<long>(x0));
}
}
#pragma omp parallel num_threads(24)
{
int tid = omp_get_thread_num();
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(768L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(256L); x1+=static_cast<long>(16L))
{
auto tmp0 = static_cast<float>(1.0);
auto tmp1 = at::vec::Vectorized<float>(tmp0);
tmp1.store(out_ptr3 + static_cast<long>(x1 + (448L*x0)));
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(192L); x2+=static_cast<long>(16L))
{
auto tmp0 = in_ptr1[static_cast<long>(64L + x1)];
auto tmp1 = in_ptr2[static_cast<long>((33L*x0) + (c10::div_floor_integer(x2, 64L)))];
auto tmp12 = in_ptr1[static_cast<long>(704L + x1)];
auto tmp13 = in_ptr2[static_cast<long>(30L + (33L*x0) + (c10::div_floor_integer(x2, 64L)))];
auto tmp2 = 13L;
auto tmp3 = c10::convert<int64_t>(tmp2);
auto tmp4 = decltype(tmp1)(tmp1 + tmp3);
auto tmp5 = tmp1 < 0;
auto tmp6 = tmp5 ? tmp4 : tmp1;
auto tmp7 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
tmpbuf[x2_inner] = static_cast<long>(tmp6);
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
TORCH_CHECK((at::vec::VecMask<int64_t,2>((at::vec::VectorizedN<int64_t,2>(0) <= tmp7) & (tmp7 < at::vec::VectorizedN<int64_t,2>(13L)))).all_masked(), "index out of bounds: 0 <= tmp7 < 13L");
auto tmp9 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
tmpbuf[x2_inner] = in_ptr1[static_cast<long>((64L*tmp6) + (static_cast<long>((x2 + x2_inner)) % static_cast<long>(64L)))];
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp10 = at::vec::Vectorized<float>(tmp0);
auto tmp11 = tmp10 * tmp9;
auto tmp14 = decltype(tmp13)(tmp13 + tmp3);
auto tmp15 = tmp13 < 0;
auto tmp16 = tmp15 ? tmp14 : tmp13;
auto tmp17 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
tmpbuf[x2_inner] = static_cast<long>(tmp16);
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
TORCH_CHECK((at::vec::VecMask<int64_t,2>((at::vec::VectorizedN<int64_t,2>(0) <= tmp17) & (tmp17 < at::vec::VectorizedN<int64_t,2>(13L)))).all_masked(), "index out of bounds: 0 <= tmp17 < 13L");
auto tmp19 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
tmpbuf[x2_inner] = in_ptr1[static_cast<long>((64L*tmp16) + (static_cast<long>((x2 + x2_inner)) % static_cast<long>(64L)))];
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp20 = at::vec::Vectorized<float>(tmp12);
auto tmp21 = tmp20 * tmp19;
tmp11.store(out_ptr4 + static_cast<long>(x2 + (448L*x1) + (28672L*x0)));
tmp21.store(out_ptr5 + static_cast<long>(x2 + (448L*x1) + (28672L*x0)));
}
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(768L); x0+=static_cast<long>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(long x1=static_cast<long>(0L); x1<static_cast<long>(448L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr3 + static_cast<long>(x1 + (448L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr4 + static_cast<long>(x1), 16);
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr5 + static_cast<long>(x1 + (448L*x0)), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp8 = at::vec::minimum(tmp6, tmp7);
auto tmp9 = static_cast<float>(1.0);
auto tmp10 = at::vec::Vectorized<float>(tmp9);
auto tmp11 = tmp10 - tmp8;
auto tmp12 = static_cast<float>(-10000.0);
auto tmp13 = at::vec::Vectorized<float>(tmp12);
auto tmp14 = tmp11 * tmp13;
auto tmp15 = tmp5 + tmp14;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp15);
}
tmp_acc0 = max_propagate_nan(tmp_acc0, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return at::vec::maximum(x, y); }, tmp_acc0_vec));
out_ptr6[static_cast<long>(x0)] = static_cast<float>(tmp_acc0);
}
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(long x1=static_cast<long>(0L); x1<static_cast<long>(448L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr3 + static_cast<long>(x1 + (448L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr4 + static_cast<long>(x1), 16);
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr5 + static_cast<long>(x1 + (448L*x0)), 16);
auto tmp16 = out_ptr6[static_cast<long>(x0)];
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp8 = at::vec::minimum(tmp6, tmp7);
auto tmp9 = static_cast<float>(1.0);
auto tmp10 = at::vec::Vectorized<float>(tmp9);
auto tmp11 = tmp10 - tmp8;
auto tmp12 = static_cast<float>(-10000.0);
auto tmp13 = at::vec::Vectorized<float>(tmp12);
auto tmp14 = tmp11 * tmp13;
auto tmp15 = tmp5 + tmp14;
auto tmp17 = at::vec::Vectorized<float>(tmp16);
auto tmp18 = tmp15 - tmp17;
auto tmp19 = tmp18.exp();
tmp19.store(out_ptr7 + static_cast<long>(x1 + (448L*x0)));
tmp_acc0_vec = tmp_acc0_vec + tmp19;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr8[static_cast<long>(x0)] = static_cast<float>(tmp_acc0);
}
for(long x1=static_cast<long>(0L); x1<static_cast<long>(448L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr7 + static_cast<long>(x1 + (448L*x0)), 16);
auto tmp1 = out_ptr8[static_cast<long>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
auto tmp4 = at::vec::convert<bfloat16>(tmp3);
tmp4.store(out_ptr9 + static_cast<long>(x1 + (448L*x0)), 16);
}
}
}
#pragma omp single
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr6 + static_cast<long>(x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr10 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
tmp0.store(out_ptr11 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
}
#pragma omp single
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr6 + static_cast<long>(49152L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr12 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
}
#pragma omp single
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr6 + static_cast<long>(98304L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr13 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
}
#pragma omp single
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr6 + static_cast<long>(589824L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr14 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
tmp0.store(out_ptr15 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(192L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(1L))
{
auto tmp0 = in_ptr2[static_cast<long>((33L*x0) + (33L*(c10::div_floor_integer((x2 + (64L*x1)), 135168L))) + (c10::div_floor_integer((x2 + (64L*x1)), 4096L)))];
auto tmp13 = in_ptr2[static_cast<long>(30L + (33L*x0) + (33L*(c10::div_floor_integer((122880L + x2 + (64L*x1)), 135168L))) + (c10::div_floor_integer((x2 + (64L*x1)), 4096L)))];
auto tmp1 = (13L*x0) + (26L*(c10::div_floor_integer((x2 + (64L*x1)), 135168L)));
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = 156L;
auto tmp5 = c10::convert<int64_t>(tmp4);
auto tmp6 = decltype(tmp3)(tmp3 + tmp5);
auto tmp7 = tmp3 < 0;
auto tmp8 = tmp7 ? tmp6 : tmp3;
auto tmp9 = tmp8;
auto tmp10 = c10::convert<int64_t>(tmp9);
TORCH_CHECK((0 <= tmp10) & (tmp10 < 156L), "index out of bounds: 0 <= tmp10 < 156L");
auto tmp12 = in_ptr6[static_cast<long>(x2 + (64L*(static_cast<long>(c10::div_floor_integer(tmp8, 13L)) % static_cast<long>(12L))) + (768L*(static_cast<long>(x1) % static_cast<long>(64L))) + (49152L*(static_cast<long>(tmp8) % static_cast<long>(13L))))];
auto tmp14 = (13L*x0) + (13L*(c10::div_floor_integer((30L + (c10::div_floor_integer((x2 + (64L*x1)), 4096L))), 33L))) + (13L*(c10::div_floor_integer((122880L + x2 + (64L*x1)), 135168L)));
auto tmp15 = c10::convert<int64_t>(tmp14);
auto tmp16 = decltype(tmp13)(tmp13 + tmp15);
auto tmp17 = decltype(tmp16)(tmp16 + tmp5);
auto tmp18 = tmp16 < 0;
auto tmp19 = tmp18 ? tmp17 : tmp16;
auto tmp20 = tmp19;
auto tmp21 = c10::convert<int64_t>(tmp20);
TORCH_CHECK((0 <= tmp21) & (tmp21 < 156L), "index out of bounds: 0 <= tmp21 < 156L");
auto tmp23 = in_ptr6[static_cast<long>(x2 + (64L*(static_cast<long>(c10::div_floor_integer(tmp19, 13L)) % static_cast<long>(12L))) + (768L*(static_cast<long>(x1) % static_cast<long>(64L))) + (49152L*(static_cast<long>(tmp19) % static_cast<long>(13L))))];
out_ptr16[static_cast<long>(x2 + (64L*x1) + (28672L*x0))] = tmp12;
out_ptr17[static_cast<long>(x2 + (64L*x1) + (28672L*x0))] = tmp23;
}
}
}
}
}
}
''')
cpp_fused_cat_clone_3 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_ezyang/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C" void kernel(const bfloat16* in_ptr0,
const bfloat16* in_ptr1,
bfloat16* out_ptr0,
bfloat16* out_ptr1,
bfloat16* out_ptr2,
bfloat16* out_ptr3)
{
#pragma omp parallel num_threads(24)
{
int tid = omp_get_thread_num();
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(9L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(1L))
{
for(long x3=static_cast<long>(0L); x3<static_cast<long>(64L); x3+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(49152L + x3 + (64L*x0) + (768L*x2) + (49152L*x1)), 32);
tmp0.store(out_ptr0 + static_cast<long>(x3 + (64L*x2) + (12288L*x1) + (110592L*x0)), 32);
}
}
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(9L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(1L))
{
for(long x3=static_cast<long>(0L); x3<static_cast<long>(64L); x3+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(98304L + x3 + (64L*x0) + (768L*x2) + (49152L*x1)), 32);
tmp0.store(out_ptr1 + static_cast<long>(x3 + (64L*x2) + (12288L*x1) + (110592L*x0)), 32);
}
}
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(9L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(1L))
{
for(long x3=static_cast<long>(0L); x3<static_cast<long>(64L); x3+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(147456L + x3 + (64L*x0) + (768L*x2) + (49152L*x1)), 32);
tmp0.store(out_ptr2 + static_cast<long>(x3 + (64L*x2) + (12288L*x1) + (110592L*x0)), 32);
}
}
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(576L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr1 + static_cast<long>(98304L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr3 + static_cast<long>(x2 + (64L*x1) + (36864L*x0)), 32);
}
}
}
}
}
}
''')
cpp_fused_clone_4 = async_compile.cpp_pybinding(['const int64_t*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_ezyang/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C" void kernel(const int64_t* in_ptr0,
const bfloat16* in_ptr1,
const bfloat16* in_ptr2,
bfloat16* out_ptr0,
bfloat16* out_ptr1)
{
#pragma omp parallel num_threads(24)
{
int tid = omp_get_thread_num();
{
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(9L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(192L); x2+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x3=static_cast<long>(0L); x3<static_cast<long>(64L); x3+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(3L + (3L*x1) + (33L*x0) + (33L*(c10::div_floor_integer((12288L + x3 + (64L*x2) + (12288L*x1)), 135168L))) + (c10::div_floor_integer((x3 + (64L*x2)), 4096L)))];
auto tmp1 = (13L*x0) + (13L*(c10::div_floor_integer((3L + (3L*x1) + (c10::div_floor_integer((x3 + (64L*x2)), 4096L))), 33L))) + (13L*(c10::div_floor_integer((12288L + x3 + (64L*x2) + (12288L*x1)), 135168L)));
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = 156L;
auto tmp5 = c10::convert<int64_t>(tmp4);
auto tmp6 = decltype(tmp3)(tmp3 + tmp5);
auto tmp7 = tmp3 < 0;
auto tmp8 = tmp7 ? tmp6 : tmp3;
auto tmp9 = tmp8;
auto tmp10 = c10::convert<int64_t>(tmp9);
TORCH_CHECK((0 <= tmp10) & (tmp10 < 156L), "index out of bounds: 0 <= tmp10 < 156L");
auto tmp12 = in_ptr1[static_cast<long>(x3 + (64L*(static_cast<long>(c10::div_floor_integer(tmp8, 13L)) % static_cast<long>(12L))) + (768L*(static_cast<long>(x2) % static_cast<long>(64L))) + (49152L*(static_cast<long>(tmp8) % static_cast<long>(13L))))];
auto tmp13 = in_ptr2[static_cast<long>(x3 + (64L*(static_cast<long>(c10::div_floor_integer(tmp8, 13L)) % static_cast<long>(12L))) + (768L*(static_cast<long>(x2) % static_cast<long>(64L))) + (49152L*(static_cast<long>(tmp8) % static_cast<long>(13L))))];
out_ptr0[static_cast<long>(x3 + (64L*x2) + (12288L*x1) + (110592L*x0))] = tmp12;
out_ptr1[static_cast<long>(x3 + (64L*x2) + (12288L*x1) + (110592L*x0))] = tmp13;
}
}
}
}
}
}
}
''')
cpp_fused__softmax__to_copy_add_cat_mul_rsub_5 = async_compile.cpp_pybinding(['const bfloat16*', 'const float*', 'const bfloat16*', 'const float*', 'const bfloat16*', 'const float*', 'const int64_t*', 'const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'float*', 'float*', 'float*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_ezyang/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C" void kernel(const bfloat16* in_ptr0,
const float* in_ptr1,
const bfloat16* in_ptr2,
const float* in_ptr3,
const bfloat16* in_ptr4,
const float* in_ptr5,
const int64_t* in_ptr6,
const bfloat16* in_ptr7,
const bfloat16* in_ptr8,
const bfloat16* in_ptr9,
bfloat16* out_ptr0,
bfloat16* out_ptr1,
bfloat16* out_ptr2,
bfloat16* out_ptr3,
float* out_ptr4,
float* out_ptr5,
float* out_ptr6,
bfloat16* out_ptr7,
bfloat16* out_ptr8,
bfloat16* out_ptr9,
bfloat16* out_ptr10)
{
#pragma omp parallel num_threads(24)
{
int tid = omp_get_thread_num();
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(6912L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(x1 + (64L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp7 = static_cast<float>(1.0);
auto tmp8 = at::vec::Vectorized<float>(tmp7);
auto tmp9 = tmp8 - tmp6;
auto tmp10 = static_cast<float>(-10000.0);
auto tmp11 = at::vec::Vectorized<float>(tmp10);
auto tmp12 = tmp9 * tmp11;
auto tmp13 = tmp5 + tmp12;
auto tmp14 = at::vec::convert<bfloat16>(tmp13);
tmp14.store(out_ptr0 + static_cast<long>(x1 + (512L*x0)), 16);
}
}
}
{
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(576L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(192L); x2+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr2 + static_cast<long>(x2 + (192L*x1) + (110592L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(x2 + (192L*x1)), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp7 = static_cast<float>(1.0);
auto tmp8 = at::vec::Vectorized<float>(tmp7);
auto tmp9 = tmp8 - tmp6;
auto tmp10 = static_cast<float>(-10000.0);
auto tmp11 = at::vec::Vectorized<float>(tmp10);
auto tmp12 = tmp9 * tmp11;
auto tmp13 = tmp5 + tmp12;
auto tmp14 = at::vec::convert<bfloat16>(tmp13);
tmp14.store(out_ptr1 + static_cast<long>(x2 + (512L*x1) + (294912L*x0)), 16);
}
}
}
}
{
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(9L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(1L))
{
for(long x3=static_cast<long>(0L); x3<static_cast<long>(192L); x3+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr4 + static_cast<long>(x3 + (192L*x2) + (12288L*x1) + (110592L*x0)), 16);
auto tmp6 = in_ptr5[static_cast<long>(128L + x2 + (64L*x1))];
auto tmp7 = in_ptr6[static_cast<long>(3L + (3L*x1) + (33L*x0) + (c10::div_floor_integer(x3, 64L)))];
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp8 = 13L;
auto tmp9 = c10::convert<int64_t>(tmp8);
auto tmp10 = decltype(tmp7)(tmp7 + tmp9);
auto tmp11 = tmp7 < 0;
auto tmp12 = tmp11 ? tmp10 : tmp7;
auto tmp13 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x3_inner = 0; x3_inner < 16; x3_inner++)
{
tmpbuf[x3_inner] = static_cast<long>(tmp12);
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
TORCH_CHECK((at::vec::VecMask<int64_t,2>((at::vec::VectorizedN<int64_t,2>(0) <= tmp13) & (tmp13 < at::vec::VectorizedN<int64_t,2>(13L)))).all_masked(), "index out of bounds: 0 <= tmp13 < 13L");
auto tmp15 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x3_inner = 0; x3_inner < 16; x3_inner++)
{
tmpbuf[x3_inner] = in_ptr5[static_cast<long>((64L*tmp12) + (static_cast<long>((x3 + x3_inner)) % static_cast<long>(64L)))];
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp16 = at::vec::Vectorized<float>(tmp6);
auto tmp17 = tmp16 * tmp15;
auto tmp18 = static_cast<float>(1.0);
auto tmp19 = at::vec::Vectorized<float>(tmp18);
auto tmp20 = tmp19 - tmp17;
auto tmp21 = static_cast<float>(-10000.0);
auto tmp22 = at::vec::Vectorized<float>(tmp21);
auto tmp23 = tmp20 * tmp22;
auto tmp24 = tmp5 + tmp23;
auto tmp25 = at::vec::convert<bfloat16>(tmp24);
tmp25.store(out_ptr2 + static_cast<long>(x3 + (512L*x2) + (32768L*x1) + (294912L*x0)), 16);
}
}
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(6912L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr7 + static_cast<long>(x1 + (64L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(768L + x1), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp7 = static_cast<float>(1.0);
auto tmp8 = at::vec::Vectorized<float>(tmp7);
auto tmp9 = tmp8 - tmp6;
auto tmp10 = static_cast<float>(-10000.0);
auto tmp11 = at::vec::Vectorized<float>(tmp10);
auto tmp12 = tmp9 * tmp11;
auto tmp13 = tmp5 + tmp12;
auto tmp14 = at::vec::convert<bfloat16>(tmp13);
tmp14.store(out_ptr3 + static_cast<long>(x1 + (512L*x0)), 16);
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(6912L); x0+=static_cast<long>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(long x1=static_cast<long>(0L); x1<static_cast<long>(512L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr8 + static_cast<long>(x1 + (512L*x0)), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp1);
}
tmp_acc0 = max_propagate_nan(tmp_acc0, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return at::vec::maximum(x, y); }, tmp_acc0_vec));
out_ptr4[static_cast<long>(x0)] = static_cast<float>(tmp_acc0);
}
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(long x1=static_cast<long>(0L); x1<static_cast<long>(512L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr8 + static_cast<long>(x1 + (512L*x0)), 16);
auto tmp2 = out_ptr4[static_cast<long>(x0)];
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 - tmp3;
auto tmp5 = tmp4.exp();
tmp5.store(out_ptr5 + static_cast<long>(x1 + (512L*x0)));
tmp_acc0_vec = tmp_acc0_vec + tmp5;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr6[static_cast<long>(x0)] = static_cast<float>(tmp_acc0);
}
for(long x1=static_cast<long>(0L); x1<static_cast<long>(512L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr5 + static_cast<long>(x1 + (512L*x0)), 16);
auto tmp1 = out_ptr6[static_cast<long>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
auto tmp4 = at::vec::convert<bfloat16>(tmp3);
tmp4.store(out_ptr7 + static_cast<long>(x1 + (512L*x0)), 16);
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(9L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(1L))
{
for(long x3=static_cast<long>(0L); x3<static_cast<long>(64L); x3+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr9 + static_cast<long>(49152L + x3 + (64L*x0) + (768L*x2) + (49152L*x1)), 32);
tmp0.store(out_ptr8 + static_cast<long>(x3 + (64L*x2) + (12288L*x1) + (110592L*x0)), 32);
}
}
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(9L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(1L))
{
for(long x3=static_cast<long>(0L); x3<static_cast<long>(64L); x3+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr9 + static_cast<long>(98304L + x3 + (64L*x0) + (768L*x2) + (49152L*x1)), 32);
tmp0.store(out_ptr9 + static_cast<long>(x3 + (64L*x2) + (12288L*x1) + (110592L*x0)), 32);
}
}
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(9L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(1L))
{
for(long x3=static_cast<long>(0L); x3<static_cast<long>(64L); x3+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr9 + static_cast<long>(147456L + x3 + (64L*x0) + (768L*x2) + (49152L*x1)), 32);
tmp0.store(out_ptr10 + static_cast<long>(x3 + (64L*x2) + (12288L*x1) + (110592L*x0)), 32);
}
}
}
}
}
}
}
''')
cpp_fused_cat_6 = async_compile.cpp_pybinding(['const bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_ezyang/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C" void kernel(const bfloat16* in_ptr0,
bfloat16* out_ptr0,
bfloat16* out_ptr1)
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(491520L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr0 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(540672L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr1 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
}
''')
cpp_fused__softmax_add_cat_minimum_mul_rsub_7 = async_compile.cpp_pybinding(['const float*', 'const bfloat16*', 'const float*', 'const float*', 'const bfloat16*', 'float*', 'float*', 'float*', 'float*', 'float*', 'float*', 'float*', 'bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_ezyang/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C" void kernel(const float* in_ptr0,
const bfloat16* in_ptr1,
const float* in_ptr2,
const float* in_ptr3,
const bfloat16* in_ptr4,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
float* out_ptr3,
float* out_ptr4,
float* out_ptr5,
float* out_ptr6,
bfloat16* out_ptr7,
bfloat16* out_ptr8,
bfloat16* out_ptr9)
{
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x0), 16);
tmp0.store(out_ptr0 + static_cast<long>(x0));
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(192L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(640L + x0), 16);
tmp0.store(out_ptr1 + static_cast<long>(x0));
}
}
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(192L); x0+=static_cast<long>(16L))
{
auto tmp0 = static_cast<float>(1.0);
auto tmp1 = at::vec::Vectorized<float>(tmp0);
tmp1.store(out_ptr2 + static_cast<long>(x0));
}
}
#pragma omp parallel num_threads(24)
{
int tid = omp_get_thread_num();
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(768L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(256L); x1+=static_cast<long>(16L))
{
auto tmp0 = static_cast<float>(1.0);
auto tmp1 = at::vec::Vectorized<float>(tmp0);
tmp1.store(out_ptr3 + static_cast<long>(x1 + (448L*x0)));
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(768L); x0+=static_cast<long>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(long x1=static_cast<long>(0L); x1<static_cast<long>(448L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr1 + static_cast<long>(x1 + (448L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x1), 16);
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(x1 + (448L*x0)), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp8 = at::vec::minimum(tmp6, tmp7);
auto tmp9 = static_cast<float>(1.0);
auto tmp10 = at::vec::Vectorized<float>(tmp9);
auto tmp11 = tmp10 - tmp8;
auto tmp12 = static_cast<float>(-10000.0);
auto tmp13 = at::vec::Vectorized<float>(tmp12);
auto tmp14 = tmp11 * tmp13;
auto tmp15 = tmp5 + tmp14;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp15);
}
tmp_acc0 = max_propagate_nan(tmp_acc0, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return at::vec::maximum(x, y); }, tmp_acc0_vec));
out_ptr4[static_cast<long>(x0)] = static_cast<float>(tmp_acc0);
}
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(long x1=static_cast<long>(0L); x1<static_cast<long>(448L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr1 + static_cast<long>(x1 + (448L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x1), 16);
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(x1 + (448L*x0)), 16);
auto tmp16 = out_ptr4[static_cast<long>(x0)];
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp8 = at::vec::minimum(tmp6, tmp7);
auto tmp9 = static_cast<float>(1.0);
auto tmp10 = at::vec::Vectorized<float>(tmp9);
auto tmp11 = tmp10 - tmp8;
auto tmp12 = static_cast<float>(-10000.0);
auto tmp13 = at::vec::Vectorized<float>(tmp12);
auto tmp14 = tmp11 * tmp13;
auto tmp15 = tmp5 + tmp14;
auto tmp17 = at::vec::Vectorized<float>(tmp16);
auto tmp18 = tmp15 - tmp17;
auto tmp19 = tmp18.exp();
tmp19.store(out_ptr5 + static_cast<long>(x1 + (448L*x0)));
tmp_acc0_vec = tmp_acc0_vec + tmp19;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr6[static_cast<long>(x0)] = static_cast<float>(tmp_acc0);
}
for(long x1=static_cast<long>(0L); x1<static_cast<long>(448L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr5 + static_cast<long>(x1 + (448L*x0)), 16);
auto tmp1 = out_ptr6[static_cast<long>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
auto tmp4 = at::vec::convert<bfloat16>(tmp3);
tmp4.store(out_ptr7 + static_cast<long>(x1 + (448L*x0)), 16);
}
}
}
#pragma omp single
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr4 + static_cast<long>(491520L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr8 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
}
#pragma omp single
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(64L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr4 + static_cast<long>(540672L + x2 + (64L*x0) + (768L*x1)), 32);
tmp0.store(out_ptr9 + static_cast<long>(x2 + (64L*x1) + (28672L*x0)), 32);
}
}
}
}
}
}
}
''')
cpp_fused__softmax_add_mul_rsub_8 = async_compile.cpp_pybinding(['const bfloat16*', 'const float*', 'float*', 'float*', 'float*', 'bfloat16*'], '''
#include "/tmp/torchinductor_ezyang/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C" void kernel(const bfloat16* in_ptr0,
const float* in_ptr1,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
bfloat16* out_ptr3)
{
#pragma omp parallel num_threads(24)
{
int tid = omp_get_thread_num();
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(768L); x0+=static_cast<long>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(long x1=static_cast<long>(0L); x1<static_cast<long>(832L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(x1 + (832L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp7 = static_cast<float>(1.0);
auto tmp8 = at::vec::Vectorized<float>(tmp7);
auto tmp9 = tmp8 - tmp6;
auto tmp10 = static_cast<float>(-10000.0);
auto tmp11 = at::vec::Vectorized<float>(tmp10);
auto tmp12 = tmp9 * tmp11;
auto tmp13 = tmp5 + tmp12;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp13);
}
tmp_acc0 = max_propagate_nan(tmp_acc0, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return at::vec::maximum(x, y); }, tmp_acc0_vec));
out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0);
}
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(long x1=static_cast<long>(0L); x1<static_cast<long>(832L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(x1 + (832L*x0)), 16);
auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1), 16);
auto tmp14 = out_ptr0[static_cast<long>(x0)];
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(0.125);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
auto tmp5 = (tmp4);
auto tmp7 = static_cast<float>(1.0);
auto tmp8 = at::vec::Vectorized<float>(tmp7);
auto tmp9 = tmp8 - tmp6;
auto tmp10 = static_cast<float>(-10000.0);
auto tmp11 = at::vec::Vectorized<float>(tmp10);
auto tmp12 = tmp9 * tmp11;
auto tmp13 = tmp5 + tmp12;
auto tmp15 = at::vec::Vectorized<float>(tmp14);
auto tmp16 = tmp13 - tmp15;
auto tmp17 = tmp16.exp();
tmp17.store(out_ptr1 + static_cast<long>(x1 + (832L*x0)));
tmp_acc0_vec = tmp_acc0_vec + tmp17;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr2[static_cast<long>(x0)] = static_cast<float>(tmp_acc0);
}
for(long x1=static_cast<long>(0L); x1<static_cast<long>(832L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<long>(x1 + (832L*x0)), 16);
auto tmp1 = out_ptr2[static_cast<long>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
auto tmp4 = at::vec::convert<bfloat16>(tmp3);
tmp4.store(out_ptr3 + static_cast<long>(x1 + (832L*x0)), 16);
}
}
}
}
}
''')
cpp_fused_cat_mul_9 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'const float*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*', 'float*'], '''
#include "/tmp/torchinductor_ezyang/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C" void kernel(const bfloat16* in_ptr0,
const bfloat16* in_ptr1,
const bfloat16* in_ptr2,
const bfloat16* in_ptr3,
const bfloat16* in_ptr4,
const bfloat16* in_ptr5,
const bfloat16* in_ptr6,
const bfloat16* in_ptr7,
const bfloat16* in_ptr8,
const float* in_ptr9,
bfloat16* out_ptr0,
bfloat16* out_ptr1,
bfloat16* out_ptr2,
bfloat16* out_ptr3,
bfloat16* out_ptr4,
float* out_ptr5)
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(4096L); x1+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(x1 + (4096L*x0)), 32);
tmp0.store(out_ptr0 + static_cast<long>(x1 + (53248L*x0)), 32);
}
}
}
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(4096L); x1+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr1 + static_cast<long>(x1 + (4096L*x0)), 32);
tmp0.store(out_ptr1 + static_cast<long>(x1 + (53248L*x0)), 32);
}
}
}
#pragma omp parallel num_threads(24)
{
int tid = omp_get_thread_num();
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(36864L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr2 + static_cast<long>(x1 + (36864L*x0)), 16);
auto tmp2 = at::vec::Vectorized<bfloat16>::loadu(in_ptr3 + static_cast<long>(x1 + (36864L*x0)), 16);
auto tmp5 = at::vec::Vectorized<bfloat16>::loadu(in_ptr4 + static_cast<long>(x1 + (36864L*x0)), 16);
auto tmp8 = at::vec::Vectorized<bfloat16>::loadu(in_ptr5 + static_cast<long>(x1 + (36864L*x0)), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp3 = at::vec::convert<float>(tmp2);
auto tmp4 = tmp1 + tmp3;
auto tmp6 = at::vec::convert<float>(tmp5);
auto tmp7 = tmp4 + tmp6;
auto tmp9 = at::vec::convert<float>(tmp8);
auto tmp10 = tmp7 + tmp9;
auto tmp11 = at::vec::convert<bfloat16>(tmp10);
tmp11.store(out_ptr2 + static_cast<long>(x1 + (53248L*x0)), 16);
}
}
}
#pragma omp single
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(4096L); x1+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr6 + static_cast<long>(x1 + (4096L*x0)), 32);
tmp0.store(out_ptr3 + static_cast<long>(x1 + (53248L*x0)), 32);
}
}
}
}
#pragma omp single
{
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(4096L); x1+=static_cast<long>(32L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr7 + static_cast<long>(x1 + (4096L*x0)), 32);
tmp0.store(out_ptr4 + static_cast<long>(x1 + (53248L*x0)), 32);
}
}
}
}
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(12L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(832L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(64L); x2+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr8 + static_cast<long>(x2 + (64L*x1) + (53248L*x0)), 16);
auto tmp2 = in_ptr9[static_cast<long>(x1)];
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 * tmp3;
tmp4.store(out_ptr5 + static_cast<long>(x2 + (64L*x1) + (53248L*x0)));
}
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1 = args
args.clear()
assert_size_stride(arg0_1, (11, 3), (3, 1))
assert_size_stride(arg1_1, (11, 3), (3, 1))
assert_size_stride(arg2_1, (11, 3), (3, 1))
assert_size_stride(arg3_1, (11, 3), (3, 1))
assert_size_stride(arg4_1, (11, 3), (3, 1))
assert_size_stride(arg5_1, (11, 3), (3, 1))
assert_size_stride(arg6_1, (11, 3), (3, 1))
assert_size_stride(arg7_1, (11, 3), (3, 1))
assert_size_stride(arg8_1, (11, 3), (3, 1))
assert_size_stride(arg9_1, (11, 3), (3, 1))
assert_size_stride(arg10_1, (11, 3), (3, 1))
assert_size_stride(arg11_1, (11, 3), (3, 1))
assert_size_stride(arg12_1, (1, 12, 832, 64), (638976, 64, 768, 1))
assert_size_stride(arg13_1, (1, 13, 64), (832, 64, 1))
assert_size_stride(arg14_1, (1, 12, 832, 64), (638976, 64, 768, 1))
assert_size_stride(arg15_1, (1, 12, 832, 64), (638976, 64, 768, 1))
assert_size_stride(arg16_1, (1, 1, 1, 832), (832, 832, 832, 1))
assert_size_stride(arg17_1, (1, 1, 9, 64, 192), (110592, 110592, 12288, 192, 1))
assert_size_stride(arg18_1, (1, 1, 832, 1), (832, 832, 1, 1))
buf0 = empty_strided_cpu((12, 64, 832), (53248, 832, 1), torch.bfloat16)
# Source Nodes: [bmm], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(arg12_1, (12, 64, 64), (64, 768, 1), 0), reinterpret_tensor(arg14_1, (12, 64, 832), (64, 1, 768), 0), out=buf0)
buf1 = empty_strided_cpu((1, 12, 64, 1), (768, 64, 1, 768), torch.float32)
buf2 = empty_strided_cpu((1, 12, 64, 832), (638976, 53248, 832, 1), torch.float32)
buf3 = empty_strided_cpu((1, 12, 64, 1), (768, 64, 1, 768), torch.float32)
buf4 = empty_strided_cpu((1, 12, 64, 832), (638976, 53248, 832, 1), torch.bfloat16)
cpp_fused__softmax_add_mul_rsub_0(buf0, arg16_1, buf1, buf2, buf3, buf4)
buf5 = empty_strided_cpu((12, 64, 64), (4096, 64, 1), torch.bfloat16)
# Source Nodes: [bmm_1], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf4, (12, 64, 832), (53248, 832, 1), 0), reinterpret_tensor(arg15_1, (12, 832, 64), (64, 768, 1), 0), out=buf5)
buf18 = empty_strided_cpu((132, 3), (3, 1), torch.int32)
buf6 = reinterpret_tensor(buf18, (11, 3), (3, 1), 0) # alias
buf7 = reinterpret_tensor(buf18, (11, 3), (3, 1), 33) # alias
buf8 = reinterpret_tensor(buf18, (11, 3), (3, 1), 66) # alias
buf9 = reinterpret_tensor(buf18, (11, 3), (3, 1), 99) # alias
buf10 = reinterpret_tensor(buf18, (11, 3), (3, 1), 132) # alias
buf11 = reinterpret_tensor(buf18, (11, 3), (3, 1), 165) # alias
buf12 = reinterpret_tensor(buf18, (11, 3), (3, 1), 198) # alias
buf13 = reinterpret_tensor(buf18, (11, 3), (3, 1), 231) # alias
buf14 = reinterpret_tensor(buf18, (11, 3), (3, 1), 264) # alias
buf15 = reinterpret_tensor(buf18, (11, 3), (3, 1), 297) # alias
buf16 = reinterpret_tensor(buf18, (11, 3), (3, 1), 330) # alias
buf17 = reinterpret_tensor(buf18, (11, 3), (3, 1), 363) # alias
buf19 = empty_strided_cpu((12, 11, 3), (33, 3, 1), torch.int64)
buf25 = empty_strided_cpu((1, 12, 448, 64), (344064, 28672, 64, 1), torch.bfloat16)
buf20 = reinterpret_tensor(buf25, (1, 12, 64, 64), (344064, 28672, 64, 1), 0) # alias
buf78 = empty_strided_cpu((1, 12, 448, 64), (344064, 28672, 64, 1), torch.bfloat16)
buf73 = reinterpret_tensor(buf78, (1, 12, 64, 64), (344064, 28672, 64, 1), 0) # alias
buf21 = reinterpret_tensor(buf25, (1, 12, 64, 64), (344064, 28672, 64, 1), 4096) # alias
buf22 = reinterpret_tensor(buf25, (1, 12, 64, 64), (344064, 28672, 64, 1), 8192) # alias
buf23 = reinterpret_tensor(buf25, (1, 12, 64, 64), (344064, 28672, 64, 1), 12288) # alias
buf76 = reinterpret_tensor(buf78, (1, 12, 64, 64), (344064, 28672, 64, 1), 12288) # alias
buf24 = reinterpret_tensor(buf25, (1, 12, 192, 64), (344064, 28672, 64, 1), 16384) # alias
buf77 = reinterpret_tensor(buf78, (1, 12, 192, 64), (344064, 28672, 64, 1), 16384) # alias
cpp_fused__to_copy_cat_stack_1(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, buf18, arg14_1, buf6, buf7, buf8, buf9, buf10, buf11, buf12, buf13, buf14, buf15, buf16, buf17, buf19, buf20, buf73, buf21, buf22, buf23, buf76, buf24, buf77)
del arg0_1
del arg10_1
del arg11_1
del arg1_1
del arg2_1
del arg3_1
del arg4_1
del arg5_1
del arg6_1
del arg7_1
del arg8_1
del arg9_1
del buf10
del buf11
del buf12
del buf13
del buf14
del buf15
del buf16
del buf17
del buf18
del buf20
del buf21
del buf22
del buf23
del buf24
del buf6
del buf7
del buf8
del buf9
buf26 = empty_strided_cpu((12, 64, 448), (28672, 448, 1), torch.bfloat16)
# Source Nodes: [bmm_2], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(arg12_1, (12, 64, 64), (64, 768, 1), 49152), reinterpret_tensor(buf25, (12, 64, 448), (28672, 1, 64), 0), out=buf26)
buf30 = empty_strided_cpu((1, 1, 1, 448), (448, 448, 448, 1), torch.float32)
buf27 = reinterpret_tensor(buf30, (1, 1, 1, 192), (448, 448, 448, 1), 0) # alias
buf28 = reinterpret_tensor(buf30, (1, 1, 1, 64), (448, 448, 448, 1), 192) # alias
buf29 = reinterpret_tensor(buf30, (1, 1, 1, 192), (448, 448, 448, 1), 256) # alias
buf33 = empty_strided_cpu((1, 12, 64, 448), (344064, 28672, 448, 1), torch.float32)
buf31 = reinterpret_tensor(buf33, (1, 12, 64, 256), (344064, 28672, 448, 1), 0) # alias
buf32 = reinterpret_tensor(buf33, (1, 12, 64, 192), (344064, 28672, 448, 1), 256) # alias
buf86 = empty_strided_cpu((1, 12, 64, 448), (344064, 28672, 448, 1), torch.float32)
buf85 = reinterpret_tensor(buf86, (1, 12, 64, 192), (344064, 28672, 448, 1), 256) # alias
buf34 = buf3; del buf3 # reuse
buf35 = empty_strided_cpu((1, 12, 64, 448), (344064, 28672, 448, 1), torch.float32)
buf36 = buf1; del buf1 # reuse
buf43 = reinterpret_tensor(buf25, (1, 12, 64, 448), (344064, 28672, 448, 1), 0); del buf25 # reuse
buf42 = empty_strided_cpu((1, 12, 448, 64), (344064, 28672, 64, 1), torch.bfloat16)
buf37 = reinterpret_tensor(buf42, (1, 12, 64, 64), (344064, 28672, 64, 1), 0) # alias
buf95 = empty_strided_cpu((1, 12, 448, 64), (344064, 28672, 64, 1), torch.bfloat16)
buf90 = reinterpret_tensor(buf95, (1, 12, 64, 64), (344064, 28672, 64, 1), 0) # alias
buf38 = reinterpret_tensor(buf42, (1, 12, 64, 64), (344064, 28672, 64, 1), 4096) # alias
buf39 = reinterpret_tensor(buf42, (1, 12, 64, 64), (344064, 28672, 64, 1), 8192) # alias
buf40 = reinterpret_tensor(buf42, (1, 12, 64, 64), (344064, 28672, 64, 1), 12288) # alias
buf93 = reinterpret_tensor(buf95, (1, 12, 64, 64), (344064, 28672, 64, 1), 12288) # alias
buf41 = reinterpret_tensor(buf42, (1, 12, 192, 64), (344064, 28672, 64, 1), 16384) # alias
buf94 = reinterpret_tensor(buf95, (1, 12, 192, 64), (344064, 28672, 64, 1), 16384) # alias
cpp_fused__softmax_add_cat_minimum_mul_new_ones_rsub_2(arg16_1, arg13_1, buf19, buf26, buf30, buf33, arg15_1, buf27, buf28, buf29, buf31, buf32, buf85, buf34, buf35, buf36, buf43, buf37, buf90, buf38, buf39, buf40, buf93, buf41, buf94)
del buf26
del buf27
del buf28
del buf29
del buf31
del buf32
del buf33
del buf37
del buf38
del buf39
del buf40
del buf41
buf44 = empty_strided_cpu((12, 64, 64), (4096, 64, 1), torch.bfloat16)
# Source Nodes: [bmm_3], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf43, (12, 64, 448), (28672, 448, 1), 0), reinterpret_tensor(buf42, (12, 448, 64), (28672, 64, 1), 0), out=buf44)
del buf42
buf45 = empty_strided_cpu((12, 576, 64), (36864, 64, 1), torch.bfloat16)
# Source Nodes: [first_band_product], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(arg12_1, (12, 576, 64), (64, 768, 1), 98304), reinterpret_tensor(arg14_1, (12, 64, 64), (64, 1, 768), 0), out=buf45)
buf49 = empty_strided_cpu((1, 12, 9, 192, 64), (1327104, 110592, 12288, 64, 1), torch.bfloat16)
buf46 = reinterpret_tensor(buf49, (1, 12, 9, 64, 64), (1327104, 110592, 12288, 64, 1), 0) # alias
buf47 = reinterpret_tensor(buf49, (1, 12, 9, 64, 64), (1327104, 110592, 12288, 64, 1), 4096) # alias
buf48 = reinterpret_tensor(buf49, (1, 12, 9, 64, 64), (1327104, 110592, 12288, 64, 1), 8192) # alias
buf50 = empty_strided_cpu((1, 12, 9, 64, 64), (442368, 36864, 4096, 64, 1), torch.bfloat16)
cpp_fused_cat_clone_3(arg14_1, arg12_1, buf46, buf47, buf48, buf50)
del buf46
del buf47
del buf48
buf51 = empty_strided_cpu((108, 64, 192), (12288, 192, 1), torch.bfloat16)
# Source Nodes: [bmm_4], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf50, (108, 64, 64), (4096, 64, 1), 0), reinterpret_tensor(buf49, (108, 64, 192), (12288, 1, 64), 0), out=buf51)
buf52 = buf49; del buf49 # reuse
buf69 = empty_strided_cpu((1, 12, 9, 192, 64), (1327104, 110592, 12288, 64, 1), torch.bfloat16)
cpp_fused_clone_4(buf19, arg14_1, arg15_1, buf52, buf69)
buf53 = empty_strided_cpu((108, 64, 192), (12288, 192, 1), torch.bfloat16)
# Source Nodes: [bmm_5], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf50, (108, 64, 64), (4096, 64, 1), 0), reinterpret_tensor(buf52, (108, 64, 192), (12288, 1, 64), 0), out=buf53)
buf54 = reinterpret_tensor(buf50, (12, 576, 64), (36864, 64, 1), 0); del buf50 # reuse
# Source Nodes: [last_band_product], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(arg12_1, (12, 576, 64), (64, 768, 1), 98304), reinterpret_tensor(arg14_1, (12, 64, 64), (64, 1, 768), 589824), out=buf54)
buf59 = empty_strided_cpu((1, 12, 9, 64, 512), (3538944, 294912, 32768, 512, 1), torch.bfloat16)
buf55 = reinterpret_tensor(buf59, (1, 12, 9, 64, 64), (3538944, 294912, 32768, 512, 1), 0) # alias
buf56 = reinterpret_tensor(buf59, (1, 12, 9, 64, 192), (3538944, 294912, 32768, 512, 1), 64) # alias
buf57 = reinterpret_tensor(buf59, (1, 12, 9, 64, 192), (3538944, 294912, 32768, 512, 1), 256) # alias
buf58 = reinterpret_tensor(buf59, (1, 12, 9, 64, 64), (3538944, 294912, 32768, 512, 1), 448) # alias
buf60 = empty_strided_cpu((1, 12, 9, 64, 1), (6912, 576, 64, 1, 6912), torch.float32)
buf61 = empty_strided_cpu((1, 12, 9, 64, 512), (3538944, 294912, 32768, 512, 1), torch.float32)
buf62 = empty_strided_cpu((1, 12, 9, 64, 1), (6912, 576, 64, 1, 6912), torch.float32)
buf67 = empty_strided_cpu((1, 12, 9, 64, 512), (3538944, 294912, 32768, 512, 1), torch.bfloat16)
buf66 = buf52; del buf52 # reuse
buf63 = reinterpret_tensor(buf66, (1, 12, 9, 64, 64), (1327104, 110592, 12288, 64, 1), 0) # alias
buf64 = reinterpret_tensor(buf66, (1, 12, 9, 64, 64), (1327104, 110592, 12288, 64, 1), 4096) # alias
buf65 = reinterpret_tensor(buf66, (1, 12, 9, 64, 64), (1327104, 110592, 12288, 64, 1), 8192) # alias
cpp_fused__softmax__to_copy_add_cat_mul_rsub_5(buf45, arg16_1, buf51, arg17_1, buf53, arg13_1, buf19, buf54, buf59, arg15_1, buf55, buf56, buf57, buf58, buf60, buf61, buf62, buf67, buf63, buf64, buf65)
del arg13_1
del arg17_1
del buf51
del buf53
del buf55
del buf56
del buf57
del buf58
del buf59
del buf60
del buf61
del buf62
del buf63
del buf64
del buf65
buf68 = reinterpret_tensor(buf54, (108, 64, 64), (4096, 64, 1), 0); del buf54 # reuse
# Source Nodes: [bmm_6], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf67, (108, 64, 192), (32768, 512, 1), 64), reinterpret_tensor(buf66, (108, 192, 64), (12288, 64, 1), 0), out=buf68)
del buf66
buf70 = reinterpret_tensor(buf45, (108, 64, 64), (4096, 64, 1), 0); del buf45 # reuse
# Source Nodes: [bmm_7], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf67, (108, 64, 192), (32768, 512, 1), 256), reinterpret_tensor(buf69, (108, 192, 64), (12288, 64, 1), 0), out=buf70)
del buf69
buf71 = empty_strided_cpu((12, 576, 64), (36864, 64, 1), torch.bfloat16)
# Source Nodes: [einsum_3], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf67, (12, 576, 64), (294912, 512, 1), 0), reinterpret_tensor(arg15_1, (12, 64, 64), (64, 768, 1), 0), out=buf71)
buf72 = empty_strided_cpu((12, 576, 64), (36864, 64, 1), torch.bfloat16)
# Source Nodes: [einsum_4], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf67, (12, 576, 64), (294912, 512, 1), 448), reinterpret_tensor(arg15_1, (12, 64, 64), (64, 768, 1), 589824), out=buf72)
del buf67
buf74 = reinterpret_tensor(buf78, (1, 12, 64, 64), (344064, 28672, 64, 1), 4096) # alias
buf75 = reinterpret_tensor(buf78, (1, 12, 64, 64), (344064, 28672, 64, 1), 8192) # alias
cpp_fused_cat_6(arg14_1, buf74, buf75)
del buf73
del buf74
del buf75
del buf76
del buf77
buf79 = reinterpret_tensor(buf43, (12, 64, 448), (28672, 448, 1), 0); del buf43 # reuse
# Source Nodes: [bmm_8], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(arg12_1, (12, 64, 64), (64, 768, 1), 540672), reinterpret_tensor(buf78, (12, 64, 448), (28672, 1, 64), 0), out=buf79)
buf83 = buf30; del buf30 # reuse
buf80 = reinterpret_tensor(buf83, (1, 1, 1, 64), (448, 448, 448, 1), 0) # alias
buf81 = reinterpret_tensor(buf83, (1, 1, 1, 192), (448, 448, 448, 1), 64) # alias
buf82 = reinterpret_tensor(buf83, (1, 1, 1, 192), (448, 448, 448, 1), 256) # alias
buf84 = reinterpret_tensor(buf86, (1, 12, 64, 256), (344064, 28672, 448, 1), 0) # alias
buf87 = buf36; del buf36 # reuse
buf88 = buf35; del buf35 # reuse
buf89 = buf34; del buf34 # reuse
buf96 = reinterpret_tensor(buf78, (1, 12, 64, 448), (344064, 28672, 448, 1), 0); del buf78 # reuse
buf91 = reinterpret_tensor(buf95, (1, 12, 64, 64), (344064, 28672, 64, 1), 4096) # alias
buf92 = reinterpret_tensor(buf95, (1, 12, 64, 64), (344064, 28672, 64, 1), 8192) # alias
cpp_fused__softmax_add_cat_minimum_mul_rsub_7(arg16_1, buf79, buf83, buf86, arg15_1, buf80, buf81, buf82, buf84, buf87, buf88, buf89, buf96, buf91, buf92)
del buf79
del buf80
del buf81
del buf82
del buf83
del buf84
del buf85
del buf86
del buf88
del buf90
del buf91
del buf92
del buf93
del buf94
buf97 = empty_strided_cpu((12, 64, 64), (4096, 64, 1), torch.bfloat16)
# Source Nodes: [bmm_9], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf96, (12, 64, 448), (28672, 448, 1), 0), reinterpret_tensor(buf95, (12, 448, 64), (28672, 64, 1), 0), out=buf97)
del buf95
del buf96
buf98 = reinterpret_tensor(buf4, (12, 64, 832), (53248, 832, 1), 0); del buf4 # reuse
# Source Nodes: [bmm_10], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(arg12_1, (12, 64, 64), (64, 768, 1), 589824), reinterpret_tensor(arg14_1, (12, 64, 832), (64, 1, 768), 0), out=buf98)
del arg12_1
del arg14_1
buf99 = buf89; del buf89 # reuse
buf100 = buf2; del buf2 # reuse
buf101 = buf87; del buf87 # reuse
buf102 = reinterpret_tensor(buf0, (1, 12, 64, 832), (638976, 53248, 832, 1), 0); del buf0 # reuse
cpp_fused__softmax_add_mul_rsub_8(buf98, arg16_1, buf99, buf100, buf101, buf102)
del arg16_1
del buf101
del buf98
del buf99
buf103 = empty_strided_cpu((12, 64, 64), (4096, 64, 1), torch.bfloat16)
# Source Nodes: [bmm_11], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf102, (12, 64, 832), (53248, 832, 1), 0), reinterpret_tensor(arg15_1, (12, 832, 64), (64, 768, 1), 0), out=buf103)
del arg15_1
buf109 = reinterpret_tensor(buf102, (1, 12, 13, 64, 64), (638976, 53248, 4096, 64, 1), 0); del buf102 # reuse
buf104 = reinterpret_tensor(buf109, (1, 12, 1, 64, 64), (638976, 53248, 4096, 64, 1), 0) # alias
buf105 = reinterpret_tensor(buf109, (1, 12, 1, 64, 64), (638976, 53248, 4096, 64, 1), 4096) # alias
buf106 = reinterpret_tensor(buf109, (1, 12, 9, 64, 64), (638976, 53248, 4096, 64, 1), 8192) # alias
buf107 = reinterpret_tensor(buf109, (1, 12, 1, 64, 64), (638976, 53248, 4096, 64, 1), 45056) # alias
buf108 = reinterpret_tensor(buf109, (1, 12, 1, 64, 64), (638976, 53248, 4096, 64, 1), 49152) # alias
buf110 = reinterpret_tensor(buf100, (1, 12, 832, 64), (638976, 53248, 64, 1), 0); del buf100 # reuse
cpp_fused_cat_mul_9(buf5, buf44, buf68, buf70, buf71, buf72, buf97, buf103, buf109, arg18_1, buf104, buf105, buf106, buf107, buf108, buf110)
del arg18_1
return (reinterpret_tensor(buf110, (1, 832, 12, 64), (638976, 64, 53248, 1), 0), reinterpret_tensor(buf19, (1, 12, 11, 3), (396, 33, 3, 1), 0), )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg1_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg2_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg3_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg4_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg5_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg6_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg7_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg8_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg9_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg10_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg11_1 = rand_strided((11, 3), (3, 1), device='cpu', dtype=torch.int32)
arg12_1 = rand_strided((1, 12, 832, 64), (638976, 64, 768, 1), device='cpu', dtype=torch.bfloat16)
arg13_1 = rand_strided((1, 13, 64), (832, 64, 1), device='cpu', dtype=torch.float32)
arg14_1 = rand_strided((1, 12, 832, 64), (638976, 64, 768, 1), device='cpu', dtype=torch.bfloat16)
arg15_1 = rand_strided((1, 12, 832, 64), (638976, 64, 768, 1), device='cpu', dtype=torch.bfloat16)
arg16_1 = rand_strided((1, 1, 1, 832), (832, 832, 832, 1), device='cpu', dtype=torch.float32)
arg17_1 = rand_strided((1, 1, 9, 64, 192), (110592, 110592, 12288, 192, 1), device='cpu', dtype=torch.float32)
arg18_1 = rand_strided((1, 1, 832, 1), (832, 832, 1, 1), device='cpu', dtype=torch.float32)
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('hf_BigBird', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment