yaoyaoding commented on code in PR #73:
URL: https://github.com/apache/tvm-ffi/pull/73#discussion_r2389188518


##########
python/tvm_ffi/cpp/load_inline.py:
##########
@@ -562,4 +565,139 @@ def load_inline(  # noqa: PLR0912, PLR0915
         _build_ninja(str(build_dir))
         # Use appropriate extension based on platform
         ext = ".dll" if IS_WINDOWS else ".so"
-        return load_module(str((build_dir / f"{name}{ext}").resolve()))
+        return str((build_dir / f"{name}{ext}").resolve())
+
+
+def load_inline(  # noqa: PLR0912, PLR0915
+    name: str,
+    *,
+    cpp_sources: Sequence[str] | str | None = None,
+    cuda_sources: Sequence[str] | str | None = None,
+    functions: Mapping[str, str] | Sequence[str] | str | None = None,
+    extra_cflags: Sequence[str] | None = None,
+    extra_cuda_cflags: Sequence[str] | None = None,
+    extra_ldflags: Sequence[str] | None = None,
+    extra_include_paths: Sequence[str] | None = None,
+    build_directory: str | None = None,
+) -> Module:
+    """Compile, build and load a C++/CUDA module from inline source code.
+
+    This function compiles the given C++ and/or CUDA source code into a shared library. The
+    ``cpp_sources`` and ``cuda_sources`` are each compiled into an object file, and the object
+    files are then linked together into a shared library. It is also possible to provide only
+    ``cpp_sources`` or only ``cuda_sources``.
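+
+    For example, both kinds of sources can be combined into a single module (a minimal sketch;
+    the variables holding the sources and the exported name are hypothetical):
+
+    .. code-block:: python
+
+        mod = load_inline(
+            name="mixed",
+            cpp_sources=cpp_source,    # declares the exported functions
+            cuda_sources=cuda_source,  # defines them with CUDA kernels
+            functions=["my_func"],
+        )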
+
+    The ``functions`` parameter specifies which functions in the source code should be exported
+    to the tvm ffi module. It can be a mapping, a sequence, or a single string. When a mapping is
+    given, the keys are the names of the exported functions and the values are their docstrings.
+    When a sequence of strings is given, it lists the names of the functions to export, and their
+    docstrings are set to empty strings. A single function name can also be given as a string,
+    indicating that only one function is to be exported.
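+
+    For illustration, the three accepted forms of ``functions`` might look as follows (the
+    function names are hypothetical):
+
+    .. code-block:: python
+
+        # mapping form: exported name -> docstring
+        functions = {"add_one_cpu": "Add one to every element of x, writing the result to y."}
+
+        # sequence form: docstrings default to empty strings
+        functions = ["add_one_cpu", "add_two_cpu"]
+
+        # single-string form: export exactly one function
+        functions = "add_one_cpu"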
+
+    Extra compiler and linker flags can be provided via the ``extra_cflags``,
+    ``extra_cuda_cflags``, and ``extra_ldflags`` parameters. The default flags are sufficient for
+    most use cases, but additional flags may be needed for a specific build.
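+
+    For example, extra flags might be passed as follows (a sketch; the flags shown are
+    illustrative rather than required, and ``cuda_source`` and ``my_kernel`` are hypothetical):
+
+    .. code-block:: python
+
+        mod = load_inline(
+            name="example",
+            cuda_sources=cuda_source,
+            functions="my_kernel",
+            extra_cuda_cflags=["-O3", "--use_fast_math"],
+        )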
+
+    The include directories of tvm ffi and dlpack are passed to the compiler by default, so any
+    header from tvm ffi can be included in the source code. Additional include paths can be
+    provided via the ``extra_include_paths`` parameter to make custom headers available as well.
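+
+    For instance (a sketch; the include path and names are hypothetical):
+
+    .. code-block:: python
+
+        mod = load_inline(
+            name="example",
+            cpp_sources=cpp_source,
+            functions="my_func",
+            extra_include_paths=["./include"],  # directory containing custom headers
+        )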
+
+    The compiled shared library is cached in a cache directory to avoid recompilation. The
+    ``build_directory`` parameter specifies the build directory; if it is not given, a default
+    tvm ffi cache directory is used. That default can be overridden via the ``TVM_FFI_CACHE_DIR``
+    environment variable and otherwise falls back to ``~/.cache/tvm-ffi``.
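+
+    The cache location can be redirected before building, for example (the path below is
+    hypothetical):
+
+    .. code-block:: python
+
+        import os
+        os.environ["TVM_FFI_CACHE_DIR"] = "/tmp/tvm-ffi-cache"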
+
+    Parameters
+    ----------
+    name: str
+        The name of the tvm ffi module.
+    cpp_sources: Sequence[str] | str, optional
+        The C++ source code. It can be a list of sources or a single source.
+    cuda_sources: Sequence[str] | str, optional
+        The CUDA source code. It can be a list of sources or a single source.
+    functions: Mapping[str, str] | Sequence[str] | str, optional
+        The functions in ``cpp_sources`` or ``cuda_sources`` that will be exported to the tvm ffi
+        module. When a mapping is given, the keys are the names of the exported functions and the
+        values are their docstrings. When a sequence or a single string is given, it names the
+        functions to export, and their docstrings are set to empty strings. When ``cpp_sources``
+        is given, the exported functions must be declared (not necessarily defined) in
+        ``cpp_sources``; when ``cpp_sources`` is not given, they must be defined in
+        ``cuda_sources``. If not specified, no function is exported.
+    extra_cflags: Sequence[str], optional
+        The extra compiler flags for C++ compilation.
+        The default flags are:
+
+        - On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
+        - On Windows: ['/std:c++17', '/O2']
+
+    extra_cuda_cflags: Sequence[str], optional
+        The extra compiler flags for CUDA compilation.
+
+    extra_ldflags: Sequence[str], optional
+        The extra linker flags.
+        The default flags are:
+
+        - On Linux/macOS: ['-shared']
+        - On Windows: ['/DLL']
+
+    extra_include_paths: Sequence[str], optional
+        The extra include paths.
+
+    build_directory: str, optional
+        The build directory. If not specified, a default tvm ffi cache directory is used:
+        ``~/.cache/tvm-ffi`` by default, or the location given by the ``TVM_FFI_CACHE_DIR``
+        environment variable.
+
+    Returns
+    -------
+    mod: Module
+        The loaded tvm ffi module.
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        import torch
+        from tvm_ffi import Module
+        import tvm_ffi.cpp
+
+        # define the cpp source code
+        cpp_source = '''
+            void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+              // check that x and y are 1D float32 tensors of the same shape
+              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              DLDataType f32_dtype{kDLFloat, 32, 1};
+              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+              // y[i] = x[i] + 1 for every element
+              for (int i = 0; i < x->shape[0]; ++i) {
+                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+              }
+            }
+        '''
+
+        # compile the cpp source code and load the module
+        mod: Module = tvm_ffi.cpp.load_inline(
+            name='hello',
+            cpp_sources=cpp_source,
+            functions='add_one_cpu'
+        )
+
+        # use the function from the loaded module to perform the computation
+        x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+        y = torch.empty_like(x)
+        mod.add_one_cpu(x, y)
+        torch.testing.assert_close(x + 1, y)
+
+    """

Review Comment:
   Prefer to keep this in the docstring, since we will generate the docs on a separate page.
   
   Fixed the docs for the return value.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

