## Problem

TensorFlow is one of the most popular machine learning libraries, and most 
developers are used to training and serving models with TensorFlow/TensorFlow 
Serving. TVM is a flexible compiler that runs computation efficiently on 
different devices. Although TensorFlow has implemented some efficient GPU 
operators, developers can benefit from TVM to get more than 10x speedup as 
well as FPGA support. However, TensorFlow and TVM have two different code 
stacks and runtime APIs.

There are two ways to integrate TVM with TensorFlow. The first is 
tensorflow-to-tvm, which is already supported by the Relay importer. Most 
TensorFlow operators can be "translated" to TVM operators, which is useful if 
we want to run the TVM stack with a model structure defined in another 
framework.

The second is tvm-to-tensorflow. This requires embedding TVM operators in the 
TensorFlow graph so that a TensorFlow session can run both built-in operators 
and TVM-optimized operators. This is really helpful if we want TVM to optimize 
part of the computation graph while developers keep using the TensorFlow 
Python API to describe the model and TensorFlow Serving for inference. 
Embedding TVM in TensorFlow requires minimal cost to apply TVM optimization to 
existing models, and it extends TensorFlow with functionality such as FPGA 
support.

This RFC describes our design for supporting tvm-to-tensorflow with the 
TensorFlow custom op API, along with the details of the implementation.

## Considerations

- Developers can use the TVM stack to build operators without limitation.
- Developers can use the TVM Python package to import and load TVM operators 
  in a TensorFlow graph.
- Developers can specify the `output_shape`/`output_dtype`/`target_device`/
  `memory_align` for TVM operators.
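
To make these requirements concrete, here is a small pure-Python sketch of the per-operator metadata such a wrapper could carry. The class and field names are hypothetical; they mirror the considerations above, not the actual `tvm.contrib.tf_op` API.

```python
from dataclasses import dataclass
from typing import List

# Hypothetical sketch: metadata a TVM-operator wrapper could expose.
# Field names mirror the considerations above, not the real tf_op API.
@dataclass
class TVMOpSpec:
    func_name: str                  # symbol name inside the TVM dynamic library
    output_shape: List[int]         # static output shape of the operator
    output_dtype: str = "float32"   # dtype of the produced tensor
    target_device: str = "cpu"      # "cpu" or "gpu"
    memory_align: int = 64          # byte alignment for tensor buffers

spec = TVMOpSpec("addone", output_shape=[2], target_device="gpu")
```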

## Proposal

The best way to extend TensorFlow functionality is to build a TensorFlow 
custom op for the TVM runtime. We build an operator called `TVMDSOOp` that 
implements CPU and GPU kernels to load any TVM dynamic library. We can run a 
TensorFlow graph with this op, which invokes TVM inference with zero-copy 
tensor data. Here are walk-through examples.

Developers can implement TVM operators with the TVM Python API. All they need 
to do is export the dynamic libraries to the local file system.

```
import tvm

# CPU version: add one to each element of a vector
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
s = tvm.create_schedule(B.op)
fadd_dylib = tvm.build(s, [A, B], "llvm", name="addone")
fadd_dylib.export_library("tvm_addone_dll.so")

# GPU version: bind the computation to CUDA block and thread axes
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
fadd_dylib = tvm.build(s, [A, B], "cuda", name="addone")
fadd_dylib.export_library("tvm_addone_cuda_dll.so")
```

With the code in our pull request, we set `set(USE_TFOP ON)` and use CMake to 
build TVM from scratch. This generates the `tvm_dso_op.so` file and provides 
`tvm.contrib.tf_op` in the Python API. Then we can use TensorFlow and TVM to 
build a graph with TVM operators and run it in a TensorFlow session.
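
The build steps might look like the following. The exact paths and flags are assumptions based on the standard TVM CMake workflow, not taken verbatim from the pull request.

```shell
# Assumed build flow: enable the TensorFlow custom op in the CMake config,
# then build TVM from source as usual.
git clone --recursive https://github.com/apache/incubator-tvm tvm
cd tvm && mkdir -p build && cp cmake/config.cmake build/
echo "set(USE_TFOP ON)" >> build/config.cmake
cd build && cmake .. && make -j4
# Expected artifacts include libtvm.so and tvm_dso_op.so
```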

```
import tensorflow as tf
from tvm.contrib import tf_op

def main():
  mod = tf_op.Module("tvm_addone_dll.so")
  addone = mod.func("addone", output_shape=[2])

  with tf.Session() as sess:
    with tf.device("/cpu:0"):
      placeholder = tf.placeholder("float32", shape=[2])
      print(sess.run(addone(placeholder), feed_dict={placeholder: [1.0, 2.0]}))

    with tf.device("/gpu:0"):
      placeholder = tf.placeholder("float32")
      addone_gpu = tf_op.Module("tvm_addone_cuda_dll.so")["addone"]
      print(sess.run(addone_gpu(placeholder), feed_dict={placeholder: [1.0, 2.0]}))

if __name__ == "__main__":
  main()
```

Since every TensorFlow custom op must declare its input tensors, we wrap the 
TVM Python API to support operators with up to 8 input tensors. Users can pass 
multiple TensorFlow tensors to `TVMDSOOp` as long as the TVM operator accepts 
multiple inputs. The Python API looks the same as in the single-input case.

```
import tensorflow as tf
from tvm.contrib import tf_op

def main():
  left = tf.placeholder("float32", shape=[4])
  right = tf.placeholder("float32", shape=[4])

  feed_dict = {
    left: [1.0, 2.0, 3.0, 4.0],
    right: [5.0, 6.0, 7.0, 8.0]
  }

  module = tf_op.Module("tvm_add_dll.so")
  add = module.func("vector_add", output_shape=tf.shape(left), output_dtype="float")

  with tf.Session() as sess:
    with tf.device("/cpu:0"):
      print(sess.run(add(left, right), feed_dict))

if __name__ == "__main__":
  main()
```
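
The fixed-arity registration can be understood with a small pure-Python sketch: a fixed-maximum-arity kernel is exposed through a variadic Python call that validates the input count. The helper below is illustrative only, not the real `tvm.contrib.tf_op` code.

```python
MAX_INPUTS = 8  # the custom op is registered with a fixed maximum arity

def wrap_variadic(fn):
    """Illustrative sketch (not the real tvm.contrib.tf_op code): expose a
    fixed-arity kernel through a variadic Python call, checking the limit."""
    def call(*inputs):
        if not 1 <= len(inputs) <= MAX_INPUTS:
            raise ValueError(
                "TVMDSOOp supports 1 to %d inputs, got %d"
                % (MAX_INPUTS, len(inputs)))
        return fn(inputs)
    return call

# Stand-in "kernel" that adds two vectors elementwise
vector_add = wrap_variadic(lambda xs: [a + b for a, b in zip(xs[0], xs[1])])
print(vector_add([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]))  # -> [6.0, 8.0, 10.0, 12.0]
```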

For more examples, please refer to 
https://github.com/tobegit3hub/tftvm/tree/master/examples .

Any TVM operator can be embedded into a TensorFlow graph with this `TVMDSOOp` 
and Python API. Thanks to zero-copy exchange via DLPack, no tensor data is 
copied between TensorFlow and TVM, so the performance overhead of the 
integration should be minimal.
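
The zero-copy claim can be illustrated with a stdlib analogy: like a DLPack tensor handed between frameworks, a `memoryview` aliases the producer's buffer instead of copying it, so mutations are visible on both sides. This is only an analogy for the sharing semantics, not TVM or TensorFlow code.

```python
# Stdlib analogy for zero-copy exchange: a memoryview aliases the
# underlying buffer, so the "consumer" sees the "producer's" data directly.
producer_buf = bytearray([1, 2, 3, 4])
consumer_view = memoryview(producer_buf)  # no bytes are copied here

producer_buf[0] = 42          # mutation by the producer...
print(consumer_view[0])       # -> 42: ...is immediately visible to the consumer
```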
