Last active
July 30, 2021 12:50
-
-
Save sathyarr/58f5147d92f8b8168c5a0d0f8b245d2e to your computer and use it in GitHub Desktop.
Custom op that replaces the `py_func` call in google/seq2seq (beam_search.py#L90). Using this custom op makes it possible to export the model for TensorFlow Serving.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "tensorflow/core/framework/op.h" | |
#include "tensorflow/core/framework/shape_inference.h" | |
#include "tensorflow/core/framework/op_kernel.h" | |
using namespace tensorflow; | |
REGISTER_OP("GatherTreePyCustom") | |
.Input("values: int32") | |
.Input("parents: int32") | |
.Output("res: int32"); | |
class GatherTreePyCustomOp : public OpKernel { | |
public: | |
explicit GatherTreePyCustomOp(OpKernelConstruction* context) : OpKernel(context) {} | |
void Compute(OpKernelContext* context) override { | |
// Grab the input tensor | |
const Tensor& input_tensor_values = context->input(0); | |
auto input_values = input_tensor_values.matrix<int32>(); | |
// Grab the input tensor | |
const Tensor& input_tensor_parents = context->input(1); | |
auto input_parents = input_tensor_parents.matrix<int32>(); | |
const TensorShape& input_tensor_values_shape = input_tensor_values.shape(); | |
const TensorShape& input_tensor_parents_shape = input_tensor_parents.shape(); | |
int beam_length = input_tensor_values_shape.dim_size(0); | |
int num_beams = input_tensor_values_shape.dim_size(1); | |
// Create an output tensor | |
Tensor* output_tensor = NULL; | |
TensorShape output_shape({input_tensor_values_shape.dim_size(0), input_tensor_values_shape.dim_size(1)}); | |
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor)); | |
auto output = output_tensor->matrix<int32>(); | |
// res in python code | |
output.setZero(); | |
for(int i = 0; i < input_tensor_values_shape.dim_size(1); i++){ | |
output(input_tensor_values_shape.dim_size(0) - 1, i) = input_values(input_tensor_values_shape.dim_size(0) - 1, i); | |
} | |
for(int beam_id = 0; beam_id < num_beams; beam_id++){ | |
int parent = input_parents(input_tensor_parents_shape.dim_size(0) - 1, beam_id); | |
for(int level = beam_length - 2; level >= 0; level--){ | |
output(level, beam_id) = input_values(level, parent); | |
parent = input_parents(level, parent); | |
} | |
} | |
} | |
}; | |
REGISTER_KERNEL_BUILDER(Name("GatherTreePyCustom").Device(DEVICE_CPU), GatherTreePyCustomOp); |
@kunalgoyal9 It should be placed in an appropriate directory when building the TensorFlow server from source.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@sathyarr Thanks for this! How did you use it in the original code?