# Integrate TVM optimization into TensorFlow with TVMDSOOp

## Introduction

The upcoming TVM release (probably 0.7) introduces a new feature called 
TVMDSOOp, which integrates TVM optimization with TensorFlow.

TVMDSOOp is a general-purpose TensorFlow custom operator that can run any 
TVM-optimized function on CPU or GPU. In other words, you can optimize a 
subgraph or implement new operators in TVM and embed them into a TensorFlow 
graph. This makes it worthwhile to try TVM and replace part of a model for 
optimization while still relying on TensorFlow infrastructure such as 
SavedModel or TensorFlow Serving.

## How To Use

TVMDSOOp is available when you build TVM from the latest source. Note that it 
is not enabled by default; enable it by adding `set(USE_TF_TVMDSOOP ON)` to 
`config.cmake`.
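
If you script your builds, this edit can be automated. The following is a 
minimal sketch, assuming your `config.cmake` (copied from TVM's 
`cmake/config.cmake`) either contains a line such as 
`set(USE_TF_TVMDSOOP OFF)` or lacks the option entirely; `enable_tf_tvmdsoop` 
is a hypothetical helper, not part of TVM.

```python
import re

# Matches an uncommented "set(USE_TF_TVMDSOOP <value>)" line, whatever the value.
_TVMDSOOP_LINE = re.compile(
    r"^[ \t]*set\([ \t]*USE_TF_TVMDSOOP[ \t]+\w+[ \t]*\)[ \t]*$", re.MULTILINE
)


def enable_tf_tvmdsoop(config_text: str) -> str:
    """Return config.cmake text with USE_TF_TVMDSOOP switched ON (hypothetical helper)."""
    if _TVMDSOOP_LINE.search(config_text):
        # Overwrite the existing setting in place.
        return _TVMDSOOP_LINE.sub("set(USE_TF_TVMDSOOP ON)", config_text)
    # No existing setting: append one at the end of the file.
    if config_text and not config_text.endswith("\n"):
        config_text += "\n"
    return config_text + "set(USE_TF_TVMDSOOP ON)\n"
```

To apply it, read your build directory's `config.cmake`, pass its contents 
through `enable_tf_tvmdsoop`, and write the result back before running CMake.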

Follow the TVM documentation at https://docs.tvm.ai/install/from_source.html 
to build with `USE_TF_TVMDSOOP` enabled and install the TVM Python package.

You can then implement the computation with pure TVM APIs. The following 
example exports a TVM vector-add operator for CPU as a shared library.

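```python
import tvm


def export_cpu_add_so():
    # A symbolic length n lets one kernel handle vectors of any size.
    n = tvm.te.var("n")
    ph_a = tvm.te.placeholder((n,), name='ph_a')  # default dtype is float32
    ph_b = tvm.te.placeholder((n,), name='ph_b')
    # Element-wise addition: ph_c[i] = ph_a[i] + ph_b[i]
    ph_c = tvm.te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
    sched = tvm.te.create_schedule(ph_c.op)
    # The "c" target emits C source; the function is named "vector_add".
    fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "c", name="vector_add")

    # Compile and write the module as a shared library TensorFlow can load.
    lib_path = "tvm_cpu_add.so"
    fadd_dylib.export_library(lib_path)
    return lib_path


if __name__ == "__main__":
    export_cpu_add_so()
```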
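
Conceptually, the `te.compute` expression above declares an element-wise loop, 
and the built function takes the output buffer `ph_c` as an argument, in the 
order given to `tvm.build`. The plain-Python sketch below only illustrates 
those semantics; `vector_add_reference` is a hypothetical name, and this is not 
the code TVM generates. It can also serve as a reference when checking the 
operator's results.

```python
from typing import List, Sequence


def vector_add_reference(a: Sequence[float], b: Sequence[float], c: List[float]) -> None:
    """Plain-Python model of ph_c[i] = ph_a[i] + ph_b[i] for a symbolic length n."""
    n = len(a)  # plays the role of the symbolic te.var("n")
    if len(b) != n or len(c) != n:
        raise ValueError("all buffers must have the same length n")
    for i in range(n):
        # One iteration of the loop described by the te.compute lambda.
        c[i] = a[i] + b[i]
```

For example, with `out = [0.0] * 4`, calling 
`vector_add_reference([1.0, 2.0, 3.0, 4.0], [1.0, 3.0, 5.0, 7.0], out)` fills 
`out` with `[2.0, 5.0, 8.0, 11.0]`, the element-wise sum of the inputs used in 
the TensorFlow test.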

With the TVM Python API, the exported library can be loaded and used like a 
normal TensorFlow operator.

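```python
import tensorflow as tf
from tvm.contrib import tf_op


def test_tvm_cpu_add_so():
    lib_path = "tvm_cpu_add.so"
    # Load the exported TVM library so its functions can be called from TensorFlow.
    module = tf_op.OpModule(lib_path)
    # Wrap "vector_add"; the output shape and dtype are given explicitly.
    tvm_add = module.func("vector_add", output_shape=[4], output_dtype="float")

    x = tf.constant([1.0, 2.0, 3.0, 4.0])
    y = tf.constant([1.0, 3.0, 5.0, 7.0])
    # Runs eagerly (TensorFlow 2.x) and prints the element-wise sum.
    print(tvm_add(x, y).numpy())


if __name__ == "__main__":
    test_tvm_cpu_add_so()
```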

To load the TVM libraries, including `libtvm_runtime` and `tvm_dso_op`, either 
install them or add the TVM build directory to `LD_LIBRARY_PATH` before running 
your script:

```shell
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/of/incubator-tvm/build/ ./your_script.py
```
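
If you prefer to launch the script from Python, the same environment change can 
be made for a child process. This is a minimal sketch; `with_tvm_library_path` 
and `run_with_tvm_libs` are hypothetical helpers, and the build directory is 
whatever path you used above. On Linux the dynamic loader reads 
`LD_LIBRARY_PATH` when a process starts, so the variable has to be set for a 
new process rather than changed inside an already running interpreter.

```python
import os
import subprocess
import sys
from typing import Dict, Mapping, Optional


def with_tvm_library_path(build_dir: str,
                          base_env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of the environment with build_dir appended to LD_LIBRARY_PATH."""
    env = dict(os.environ if base_env is None else base_env)
    existing = env.get("LD_LIBRARY_PATH", "")
    # Append after any existing entries, like the shell command above.
    env["LD_LIBRARY_PATH"] = f"{existing}:{build_dir}" if existing else build_dir
    return env


def run_with_tvm_libs(script: str, build_dir: str) -> int:
    """Run a Python script in a child process that can find the TVM libraries."""
    return subprocess.call([sys.executable, script], env=with_tvm_library_path(build_dir))
```

For example, `run_with_tvm_libs("your_script.py", "/path/of/incubator-tvm/build/")` 
runs the script with the build directory on the library search path and 
returns its exit code.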

Now you can freely combine the TVM runtime with TensorFlow sessions. Enjoy hacking!

## How It Works

The implementation of TVMDSOOp is straightforward. Here is the overall 
architecture:

![TVMDSOOp architecture](https://wiki.4paradigm.com/download/attachments/77194865/tvm_how_it_work.002.jpeg?version=1&modificationDate=1588923941325&api=v2)

Since TVMDSOOp is a TensorFlow custom operator, we implement its computation 
kernel and register it as a TensorFlow operator. The operator accepts a list of 
tensors as input arguments and lets the user set the shape and dtype of the 
output tensor.
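
The contract described above, a list of input tensors plus a user-specified 
output shape and dtype, can be modeled in a few lines of Python. This is a 
hypothetical sketch for illustration only (`run_dso_op` and 
`DTYPE_TO_TYPECODE` are not TVM APIs, and the real kernel is written in C++): 
it allocates the output buffer from the given shape and dtype, then calls the 
function with the inputs followed by the output, which is consistent with the 
argument order `[ph_a, ph_b, ph_c]` used when building `vector_add`.

```python
import array
from functools import reduce
from operator import mul
from typing import Callable, Sequence

# Hypothetical mapping from dtype strings to array.array typecodes.
DTYPE_TO_TYPECODE = {"float": "f", "float32": "f", "float64": "d", "int32": "i"}


def run_dso_op(func: Callable[..., None],
               inputs: Sequence[array.array],
               output_shape: Sequence[int],
               output_dtype: str) -> array.array:
    """Allocate the output from shape/dtype, then call func(*inputs, output)."""
    size = reduce(mul, output_shape, 1)
    typecode = DTYPE_TO_TYPECODE[output_dtype]
    output = array.array(typecode, [0]) * size  # flat, row-major storage
    func(*inputs, output)  # the function writes its result in place
    return output


def vector_add(a, b, c):
    # Stand-in for the compiled "vector_add" function: c[i] = a[i] + b[i].
    for i in range(len(c)):
        c[i] = a[i] + b[i]
```

For example, `run_dso_op(vector_add, [x, y], output_shape=[4], output_dtype="float")` 
with `x = array.array("f", [1.0, 2.0, 3.0, 4.0])` and 
`y = array.array("f", [1.0, 3.0, 5.0, 7.0])` returns a float32 array holding 
`[2.0, 5.0, 8.0, 11.0]`.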

Using the TVM runtime APIs, the kernel loads the TVM dynamic library as a 
`Module` and retrieves the function that the user's TVM Python script built and 
named (for example `vector_add`). TensorFlow passes `tensorflow::Tensor` 
objects to kernels, while the TVM runtime expects DLPack tensors, so TVMDSOOp 
converts the tensor data automatically and with minimal overhead. Users only 
need to optimize their TVM Python scripts and can then use the resulting 
operators in a TensorFlow graph without extra integration work.
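
The real conversion happens in the C++ kernel. As a rough illustration only, 
the ctypes sketch below mirrors DLPack's `DLTensor` struct (assumed layout of 
recent `dlpack.h`; older DLPack releases name the device struct `DLContext` and 
the field `ctx`) and shows how an existing float32 buffer can be described as a 
`DLTensor` by pointer, without copying the data. `wrap_float32_buffer` is a 
hypothetical helper, not TVMDSOOp's code.

```python
import array
import ctypes


# Minimal ctypes mirror of DLPack's C structs (assumed layout, see above).
class DLDevice(ctypes.Structure):
    _fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]


class DLDataType(ctypes.Structure):
    _fields_ = [("code", ctypes.c_uint8), ("bits", ctypes.c_uint8), ("lanes", ctypes.c_uint16)]


class DLTensor(ctypes.Structure):
    _fields_ = [
        ("data", ctypes.c_void_p),
        ("device", DLDevice),
        ("ndim", ctypes.c_int),
        ("dtype", DLDataType),
        ("shape", ctypes.POINTER(ctypes.c_int64)),
        ("strides", ctypes.POINTER(ctypes.c_int64)),
        ("byte_offset", ctypes.c_uint64),
    ]


KDL_CPU = 1    # DLDeviceType value for CPU
KDL_FLOAT = 2  # DLDataTypeCode value for floating point


def wrap_float32_buffer(buf: array.array, shape):
    """Describe an existing float32 buffer as a DLTensor without copying its data."""
    if buf.typecode != "f":
        raise TypeError("expected a float32 ('f') array")
    address, _ = buf.buffer_info()
    shape_arr = (ctypes.c_int64 * len(shape))(*shape)
    tensor = DLTensor(
        data=address,                  # points at the existing memory
        device=DLDevice(KDL_CPU, 0),
        ndim=len(shape),
        dtype=DLDataType(KDL_FLOAT, 32, 1),
        shape=ctypes.cast(shape_arr, ctypes.POINTER(ctypes.c_int64)),
        strides=None,                  # NULL strides mean compact row-major
        byte_offset=0,
    )
    # The caller must keep shape_arr (and buf) alive while tensor is in use.
    return tensor, shape_arr
```

Because `data` is only a pointer, writes to the original buffer are visible 
through the `DLTensor`; the trade-off is that the caller is responsible for 
keeping the underlying memory alive.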

For more details and the implementation code, see the merged pull request 
https://github.com/apache/incubator-tvm/pull/4459.




