[apache/incubator-tvm] [RFC] Add TVMDSOOp to integrate any TVM operator with TensorFlow (#4464)

2019-12-03 Thread tobe
## Problem

TensorFlow is one of the most popular machine learning libraries, and most 
developers are used to training models and running inference with TensorFlow 
and TensorFlow Serving. TVM is a flexible compiler that runs computation 
efficiently on different devices. Although TensorFlow has implemented some 
efficient GPU operators, developers can benefit from TVM, which can deliver 
speedups of more than 10 times for some operators, as well as FPGA support. 
However, TensorFlow and TVM have two different code stacks and runtime APIs.

There are two ways to integrate TVM with TensorFlow. The first is 
tensorflow-to-tvm, which is already supported by the Relay importer. Most 
TensorFlow operators can be “translated” to TVM operators, which is useful if 
you want to run the TVM stack with a model structure from another framework.

The second is tvm-to-tensorflow. This requires embedding TVM operators in the 
TensorFlow graph so that a TensorFlow session can run both built-in operators 
and TVM-optimized operators. This is helpful when we want to use TVM to 
optimize part of the computation graph while developers keep using the 
TensorFlow Python API to describe the model and TensorFlow Serving for 
inference. Embedding TVM in TensorFlow is the lowest-cost way to apply TVM 
optimization to existing models and to extend TensorFlow with functionality 
such as FPGA support.

This RFC describes how we designed tvm-to-tensorflow support using the 
TensorFlow custom op API, along with the details of the implementation.

## Considerations

- Developers can use the TVM stack to build operators without limitation.
- Developers can use the TVM Python package to import and load TVM operators 
  in a TensorFlow graph.
- Developers can specify the output_shape/output_dtype/target_device/memory_align 
  for TVM operators.
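Grouping these settings per operator makes them easy to validate before a graph is built. The sketch below is purely hypothetical: `TVMOpAttrs`, the supported dtype list, and the checks are illustrative and are not part of TVM or TensorFlow. It assumes `memory_align` is given in bytes and must be a power of two.

```python
from dataclasses import dataclass
from typing import List

# Hypothetical subset of dtypes accepted for the output tensor.
SUPPORTED_DTYPES = {"float32", "float64", "int32", "int64"}


@dataclass
class TVMOpAttrs:
    """Hypothetical per-operator settings mirroring the list above."""

    output_shape: List[int]
    output_dtype: str = "float32"
    target_device: str = "/cpu:0"
    memory_align: int = 64  # assumed to be a byte alignment

    def validate(self) -> None:
        """Raise ValueError if any setting is malformed."""
        if not all(isinstance(d, int) and d >= 0 for d in self.output_shape):
            raise ValueError(f"invalid output_shape: {self.output_shape}")
        if self.output_dtype not in SUPPORTED_DTYPES:
            raise ValueError(f"unsupported output_dtype: {self.output_dtype}")
        if not self.target_device.startswith(("/cpu:", "/gpu:")):
            raise ValueError(f"unknown target_device: {self.target_device}")
        # A positive power of two has exactly one bit set.
        align = self.memory_align
        if align <= 0 or (align & (align - 1)) != 0:
            raise ValueError(f"memory_align must be a power of two: {align}")
```

For example, `TVMOpAttrs(output_shape=[2]).validate()` passes, while `memory_align=48` is rejected.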

## Proposal

The best way to extend TensorFlow functionality is to build a TensorFlow 
custom op for the TVM runtime. We built an operator called `TVMDSOOp`, which 
implements CPU and GPU kernels that can load any TVM dynamic library. A 
TensorFlow graph containing this op invokes TVM inference with zero-copy 
tensor data. Here are some walk-through examples.

Developers can implement TVM operators with the TVM Python API. All they need 
to do is export the dynamic libraries to the local file system.

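```
import tvm

# TVM 0.6-era API; later releases moved these helpers under tvm.te
# (tvm.te.var, tvm.te.placeholder, tvm.te.compute, ...).
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
s = tvm.create_schedule(B.op)

# Build for CPU with LLVM and export the shared library.
fadd_dylib = tvm.build(s, [A, B], "llvm", name="addone")
fadd_dylib.export_library("tvm_addone_dll.so")

# Split the loop and bind it to CUDA blocks/threads, then build for GPU.
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
fadd_dylib = tvm.build(s, [A, B], "cuda", name="addone")
fadd_dylib.export_library("tvm_addone_cuda_dll.so")
```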

With the code in our pull request, set `set(USE_TFOP ON)` in `config.cmake` 
and build TVM from scratch with CMake. This generates the `tvm_dso_op.so` file 
and provides the `tvm.contrib.tf_op` Python API. We can then use TensorFlow 
and TVM to build a graph with TVM operators and run it in a TensorFlow session.

```
import tensorflow as tf
from tvm.contrib import tf_op

def main():
  mod = tf_op.Module("tvm_addone_dll.so")
  addone = mod.func("addone", output_shape=[2])

  with tf.Session() as sess:
    with tf.device("/cpu:0"):
      placeholder = tf.placeholder("float32", shape=[2])
      print(sess.run(addone(placeholder), feed_dict={placeholder: [1.0, 2.0]}))

    with tf.device("/gpu:0"):
      placeholder = tf.placeholder("float32")
      addone_gpu = tf_op.Module("tvm_addone_cuda_dll.so")["addone"]
      print(sess.run(addone_gpu(placeholder), feed_dict={placeholder: [1.0, 2.0]}))

if __name__ == "__main__":
  main()
```

Since every TensorFlow custom op must declare its input tensors, we wrap the 
TVM Python API to support operators with up to 8 input tensors. Users can pass 
multiple TensorFlow tensors to TVMDSOOp when the TVM operator takes multiple 
inputs; the Python API looks the same as in the single-input case.
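The arity limit can be pictured as a thin Python wrapper that checks how many tensors it receives before calling a fixed-arity kernel. This is a hypothetical sketch, not the code in `tvm.contrib.tf_op`: `make_tvm_op` and `MAX_INPUTS` are illustrative names, and a plain Python function stands in for the compiled TVM `vector_add`.

```python
from typing import Callable, Sequence

MAX_INPUTS = 8  # input limit described in the RFC


def make_tvm_op(kernel: Callable[..., list], num_inputs: int) -> Callable[..., list]:
    """Wrap `kernel` so it is only ever called with exactly `num_inputs` tensors."""
    if not 1 <= num_inputs <= MAX_INPUTS:
        raise ValueError(f"supports 1..{MAX_INPUTS} inputs, got {num_inputs}")

    def op(*tensors: Sequence[float]) -> list:
        if len(tensors) != num_inputs:
            raise TypeError(f"expected {num_inputs} inputs, got {len(tensors)}")
        return kernel(*tensors)

    return op


# Stand-in for a compiled TVM "vector_add" function.
def vector_add(left, right):
    return [a + b for a, b in zip(left, right)]


add = make_tvm_op(vector_add, num_inputs=2)
print(add([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]))  # [6.0, 8.0, 10.0, 12.0]
```

Checking the count in the wrapper keeps the error on the Python side, where it is easier to read than a failure inside the compiled kernel.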

```
import tensorflow as tf
from tvm.contrib import tf_op

def main():
  left = tf.placeholder("float32", shape=[4])
  right = tf.placeholder("float32", shape=[4])

  feed_dict = {
    left: [1.0, 2.0, 3.0, 4.0],
    right: [5.0, 6.0, 7.0, 8.0]
  }

  module = tf_op.Module("tvm_add_dll.so")
  add = module.func("vector_add", output_shape=tf.shape(left), output_dtype="float")

  with tf.Session() as sess:
    with tf.device("/cpu:0"):
      print(sess.run(add(left, right), feed_dict))

if __name__ == "__main__":
  main()
```

For more examples, please refer to 
https://github.com/tobegit3hub/tftvm/tree/master/examples .

Any TVM operator can be embedded into a TensorFlow graph with `TVMDSOOp` and 
this Python API. Tensor data is passed from TensorFlow (`tensorflow::Tensor`) 
to TVM (DLPack) without copying, so the integration adds little data-transfer 
overhead.
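The difference between copying and zero-copy sharing can be illustrated with Python's standard `array` and `memoryview` types. This is only an analogy for how TVMDSOOp hands `tensorflow::Tensor` storage to TVM; the variable names are illustrative, and the point is that the view and the original share one buffer.

```python
from array import array

# A float32 buffer standing in for the TensorFlow tensor's storage.
tf_storage = array("f", [1.0, 2.0, 3.0, 4.0])

copied = list(tf_storage)        # copy: new storage, independent of tf_storage
shared = memoryview(tf_storage)  # zero-copy: a view over the same bytes

# A "kernel" writing through the shared view updates the original buffer.
for i in range(len(shared)):
    shared[i] += 1.0

print(tf_storage.tolist())  # [2.0, 3.0, 4.0, 5.0]
print(copied)               # [1.0, 2.0, 3.0, 4.0]
```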

-- 
You are receiving this because you are subscribed to this thread.
Reply to this email directly or view it on GitHub:
https://github.com/apache/incubator-tvm/issu

Re: [apache/incubator-tvm] [RFC] Add TVMDSOOp to integrate any TVM operator with TensorFlow (#4464)

2019-12-03 Thread tobe
The implementation of this proposal has been submitted in 
https://github.com/apache/incubator-tvm/pull/4459 .

Anyone can test their TVM operators by recompiling TVM with 
`set(USE_TFOP ON)`.

Re: [apache/incubator-tvm] [RFC] Add TVMDSOOp to integrate any TVM operator with TensorFlow (#4464)

2019-12-08 Thread tobe
Thanks @jwfromm, you're definitely right. This is the fastest way to 
integrate TVM functions into TensorFlow when we cannot convert the whole model 
to TVM. It may be valuable for TensorFlow developers who want to try TVM and 
leverage sub-graph optimization from TVM.

This project is essentially a TensorFlow custom op built on the TVM runtime. 
We originally developed it in the standalone project 
https://github.com/tobegit3hub/tftvm . Since it depends on both TVM and 
TensorFlow to compile, it could live either as one of the TVM contrib 
libraries or as an independent project.

Re: [apache/incubator-tvm] [RFC] Add TVMDSOOp to integrate any TVM operator with TensorFlow (#4464)

2019-12-19 Thread tobe
Hi @tqchen @jwfromm @jroesch @soiferj , do you have any other comment?

We may add more docs about the implementation and usage so that everyone can 
understand how it works.

Re: [apache/incubator-tvm] [RFC] Add TVMDSOOp to integrate any TVM operator with TensorFlow (#4464)

2020-04-29 Thread tobe
The PR has been merged and we will close this issue.

Re: [apache/incubator-tvm] [RFC] Add TVMDSOOp to integrate any TVM operator with TensorFlow (#4464)

2020-04-29 Thread tobe
Closed #4464.

Re: [apache/incubator-tvm] [RFC] Add TVMDSOOp to integrate any TVM operator with TensorFlow (#4464)

2020-08-13 Thread tobe
@652994331 You should not use `tftvm`, which is deprecated; please rebuild 
TVM with `USE_TF_TVMDSOOP=ON`.

Re: [apache/incubator-tvm] [RFC] Add TVMDSOOp to integrate any TVM operator with TensorFlow (#4464)

2020-08-14 Thread tobe
@652994331 You need to install `tensorflow` so that TVMDSOOp can link to the 
TensorFlow libraries.

Here is the error message from your CMake output.

```
ModuleNotFoundError: No module named 'tensorflow'
```

[TVM Discuss] [Development] Add the document for TVMDSOOp

2020-05-08 Thread tobe@4Paradigm via TVM Discuss


# Integrate TVM optimization into TensorFlow with TVMDSOOp

## Introduction

The next release of TVM (maybe 0.7) will include a new feature called 
TVMDSOOp that integrates TVM optimization with TensorFlow.

TVMDSOOp is a general-purpose custom operator for TensorFlow that can run any 
TVM-optimized function on CPU or GPU. In other words, you can optimize a 
subgraph or implement new operators in TVM and embed them into a TensorFlow 
graph easily. This makes it practical to try TVM and replace part of a model 
for optimization while still using TensorFlow infrastructure such as 
SavedModel and TensorFlow Serving.

## How To Use

You can now use TVMDSOOp by compiling the latest TVM code. Note that TVMDSOOp 
is not enabled by default; you need to `set(USE_TF_TVMDSOOP ON)` in 
`config.cmake`.

Follow the TVM documentation at https://docs.tvm.ai/install/from_source.html 
to compile with `USE_TF_TVMDSOOP` and install the TVM Python package.

Now you can use pure TVM APIs to implement the computation operators. The 
following example exports the library of a TVM add operator for CPU.

```
import tvm

def export_cpu_add_so():
    # Element-wise add of two vectors with a symbolic length n.
    n = tvm.te.var("n")
    ph_a = tvm.te.placeholder((n,), name='ph_a')
    ph_b = tvm.te.placeholder((n,), name='ph_b')
    ph_c = tvm.te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
    sched = tvm.te.create_schedule(ph_c.op)
    # "vector_add" is the function name that TensorFlow looks up later.
    fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "c", name="vector_add")

    lib_path = "tvm_cpu_add.so"
    fadd_dylib.export_library(lib_path)

if __name__ == "__main__":
    export_cpu_add_so()
```

With the latest TVM Python APIs, we can load dynamic libraries easily and use 
them like normal TensorFlow operators.

```
import tensorflow as tf
from tvm.contrib import tf_op

def test_tvm_cpu_add_so():
    lib_path = "tvm_cpu_add.so"
    module = tf_op.OpModule(lib_path)
    tvm_add = module.func("vector_add", output_shape=[4], output_dtype="float")

    x = tf.constant([1.0, 2.0, 3.0, 4.0])
    y = tf.constant([1.0, 3.0, 5.0, 7.0])
    print(tvm_add(x, y).numpy())
```

To load the TVM libraries, including `libtvm_runtime` and `tvm_dso_op`, 
install them or add the TVM build directory to `LD_LIBRARY_PATH` before 
running your script:

`LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/of/incubator-tvm/build/ ./your_script.py`
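If you prefer to launch the script from Python, the search path has to be set for a new process: on glibc-based Linux the dynamic loader reads `LD_LIBRARY_PATH` when a process starts, so changing `os.environ` inside an already running interpreter generally does not affect libraries loaded later. A minimal sketch; `env_with_library_path` and `run_with_tvm_libs` are hypothetical helper names, and the paths are the placeholders from the command above.

```python
import os
import subprocess
import sys
from typing import Mapping, Optional


def env_with_library_path(build_dir: str,
                          base_env: Optional[Mapping[str, str]] = None) -> dict:
    """Return a copy of base_env (default: os.environ) with build_dir
    appended to LD_LIBRARY_PATH."""
    env = dict(os.environ if base_env is None else base_env)
    current = env.get("LD_LIBRARY_PATH")
    env["LD_LIBRARY_PATH"] = f"{current}{os.pathsep}{build_dir}" if current else build_dir
    return env


def run_with_tvm_libs(script: str, build_dir: str) -> int:
    """Run `script` in a new interpreter whose dynamic loader sees build_dir."""
    result = subprocess.run([sys.executable, script],
                            env=env_with_library_path(build_dir), check=True)
    return result.returncode
```

For example, `run_with_tvm_libs("your_script.py", "/path/of/incubator-tvm/build/")` behaves like the shell command above.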

Now enjoy mixing the TVM runtime with TensorFlow sessions.

## How It Works

The implementation of TVMDSOOp is straightforward. Here is the overall 
architecture.

![TVMDSOOp architecture](https://wiki.4paradigm.com/download/attachments/77194865/tvm_how_it_work.002.jpeg?version=1&modificationDate=1588923941325&api=v2 "tvm_how_it_work.002.jpeg")

Since TVMDSOOp is a TensorFlow custom operator, we need to implement its 
computation kernel and register it as a TensorFlow operator. The operator 
accepts a list of tensors as input arguments, plus settings for the shape and 
dtype of the output tensor.
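Needing the output shape and dtype up front follows from how TVM-compiled functions are called: the output is an argument too (as in `tvm.build(sched, [ph_a, ph_b, ph_c], ...)` above), so the caller allocates it and the kernel writes into it. The pure-Python sketch below imitates that destination-passing flow; `run_kernel`, `vector_add_kernel`, and the dtype mapping are hypothetical stand-ins, not TensorFlow or TVM code.

```python
from array import array
from math import prod
from typing import Callable, List, Sequence

# Illustrative mapping from an output_dtype setting to an array typecode.
TYPECODES = {"float32": "f", "float64": "d"}


def vector_add_kernel(a: array, b: array, out: array) -> None:
    """Stand-in for a compiled TVM function: writes a + b into `out`."""
    for i in range(len(out)):
        out[i] = a[i] + b[i]


def run_kernel(kernel: Callable[..., None], inputs: Sequence[array],
               output_shape: List[int], output_dtype: str) -> array:
    """Allocate the output from the op's shape/dtype settings, then call the kernel."""
    out = array(TYPECODES[output_dtype], [0] * prod(output_shape))
    kernel(*inputs, out)  # inputs first, caller-owned output last
    return out


x = array("f", [1.0, 2.0, 3.0, 4.0])
y = array("f", [1.0, 3.0, 5.0, 7.0])
print(run_kernel(vector_add_kernel, [x, y], [4], "float32").tolist())  # [2.0, 5.0, 8.0, 11.0]
```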

With the TVM runtime APIs, we can load a TVM dynamic library as a `Module` and 
get the function by the name given in the user's TVM Python script. Although 
TensorFlow passes `tensorflow::Tensor` objects to kernels while the TVM runtime 
expects DLPack tensors, TVMDSOOp automatically converts between them at 
minimal cost. Users only need to optimize their TVM Python scripts and use the 
operators in a TensorFlow graph, with no extra integration work.
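The conversion can be cheap because a DLPack tensor is just a small descriptor pointing at existing memory. The sketch below mirrors the C `DLTensor` layout from DLPack 0.x (where the device field was still called `ctx`; later versions renamed it `device`) using `ctypes`, and describes a standard-library `array` without copying it. It illustrates the idea only and is not TVMDSOOp's C++ implementation; `wrap_float32` is a hypothetical helper.

```python
import ctypes
from array import array


# ctypes mirror of DLPack's C structs (DLPack 0.x layout).
class DLContext(ctypes.Structure):
    _fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]


class DLDataType(ctypes.Structure):
    _fields_ = [("code", ctypes.c_uint8), ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]


class DLTensor(ctypes.Structure):
    _fields_ = [
        ("data", ctypes.c_void_p),
        ("ctx", DLContext),
        ("ndim", ctypes.c_int),
        ("dtype", DLDataType),
        ("shape", ctypes.POINTER(ctypes.c_int64)),
        ("strides", ctypes.POINTER(ctypes.c_int64)),
        ("byte_offset", ctypes.c_uint64),
    ]


K_DL_CPU = 1    # DLDeviceType value for CPU
K_DL_FLOAT = 2  # DLDataTypeCode value for floating point


def wrap_float32(buf: array):
    """Describe a 1-D float32 array as a DLTensor without copying its data.

    Returns (tensor, shape); keep `shape` alive while the tensor is in use,
    because the tensor only stores a pointer to it.
    """
    if buf.typecode != "f":
        raise TypeError("expected an array('f') buffer")
    shape = (ctypes.c_int64 * 1)(len(buf))
    address, _count = buf.buffer_info()  # address of the existing storage
    tensor = DLTensor(
        data=address,
        ctx=DLContext(K_DL_CPU, 0),
        ndim=1,
        dtype=DLDataType(K_DL_FLOAT, 32, 1),
        shape=ctypes.cast(shape, ctypes.POINTER(ctypes.c_int64)),
        strides=None,  # NULL strides mean a compact, row-major layout
        byte_offset=0,
    )
    return tensor, shape
```

Because `tensor.data` is the array's own address, anything reading through the descriptor sees later writes to the array, which is the zero-copy property.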

For more detail and code of the implementation, please refer to the merged 
pull-request in https://github.com/apache/incubator-tvm/pull/4459 .





---
[Visit Topic](https://discuss.tvm.ai/t/add-the-document-for-tvmdsoop/6622/1) to 
respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click 
here](https://discuss.tvm.ai/email/unsubscribe/dd9f9874f5ac248b6f276541c94a2e2780e786250a19232fdc9870f9719ae411).


[TVM Discuss] [Development] Add the document for TVMDSOOp

2020-05-08 Thread tobe@4Paradigm via TVM Discuss


We have added the document to Discuss first; please help review it if you have 
time @tqchen @FrozenGene @zhiics @gmagogsfm . We hope to add it to the 
official documentation when the content is ready.





[TVM Discuss] [Development] Add the document for TVMDSOOp

2020-06-11 Thread tobe@4Paradigm via TVM Discuss


Thanks @zhiics and @FrozenGene . We also have a Keras example with TVMDSOOp, 
and we will update the document in Google Docs later, which may help with the 
review.




