This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 474cde494d [Optimization][Operator] Implement and enable
Conv2d-Reshape-Add-ReLU fusion (#18240)
474cde494d is described below
commit 474cde494dac47cab60d8cd4181ae22bfee98bf5
Author: kimm240 <[email protected]>
AuthorDate: Mon Mar 16 13:10:50 2026 +0900
[Optimization][Operator] Implement and enable Conv2d-Reshape-Add-ReLU
fusion (#18240)
This commit extends the make_fused_bias_activation_pattern function to
support
PyTorch frontend's specific IR generation pattern for convolution
operations
with bias. When PyTorch models with bias=True are converted to Relax IR,
the
frontend generates a conv2d -> reshape -> add -> relu sequence instead
of the
simpler conv2d -> add -> relu pattern that existing fusion logic
expected.
The key changes include:
1. Add allow_reshape parameter to make_fused_bias_activation_pattern in
both
dpl/pattern.py and backend/patterns.py with default value False to
maintain
backward compatibility.
2. When allow_reshape=True, the pattern matcher now recognizes and fuses
the
complete conv2d -> reshape -> add -> relu sequence into a single
composite
function, eliminating intermediate tensor allocations and kernel launch
overhead.
3. The original pattern (allow_reshape=False) only fuses conv2d -> add
-> relu,
leaving the reshape operation outside the fused function, which results
in
suboptimal performance for PyTorch-originated models.
This enhancement enables more efficient operator fusion for PyTorch
models,
reducing memory usage and improving execution performance by capturing
the
complete computation pattern in a single fused kernel. The
implementation
maintains full backward compatibility while extending support for
PyTorch
frontend's specific IR generation patterns.
Comprehensive tests are added to verify the fusion behavior with both
old and
new patterns, ensuring correctness across different convolution types
(Conv1d,
Conv2d, Conv3d) and validating that fusion only occurs when appropriate
conditions are met.
---------
Co-authored-by: kim hyun gyu <[email protected]>
---
python/tvm/relax/backend/patterns.py | 10 +-
python/tvm/relax/dpl/pattern.py | 13 +-
.../relax/test_fuse_pytorch_conv2d_bias_pattern.py | 158 +++++++++++++++++++++
3 files changed, 177 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 9d47bb8354..c9c853d702 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -37,10 +37,15 @@ def _with_bias_activation_pattern(
annotations: dict[str, DFPattern],
with_bias: bool = False,
activation: str | None = None,
+ allow_reshape: bool = False,
) -> tuple[DFPattern, Mapping[str, DFPattern]]:
if with_bias:
annotations["bias"] = bias = wildcard()
- out = is_op("relax.add")(out, bias)
+ if allow_reshape:
+ reshaped_bias = is_op("relax.reshape")(bias, wildcard(), varg_default_wildcard=True)
+ out = is_op("relax.add")(out, reshaped_bias, varg_default_wildcard=True)
+ else:
+ out = is_op("relax.add")(out, bias)
if activation:
out = is_op(activation)(out)
@@ -52,6 +57,7 @@ def make_fused_bias_activation_pattern(
op_name: str,
with_bias: bool = False,
activation: str | None = None,
+ allow_reshape: bool = False,
) -> tuple[DFPattern, Mapping[str, DFPattern]]:
"""
A simple utility to create patterns for an operation fused with bias
addition and activation.
@@ -82,7 +88,7 @@ def make_fused_bias_activation_pattern(
out = is_op(op_name)(lhs, rhs)
annotations = {"lhs": lhs, "rhs": rhs, "root": out}
- return _with_bias_activation_pattern(out, annotations, with_bias, activation)
+ return _with_bias_activation_pattern(out, annotations, with_bias, activation, allow_reshape)
def make_residual_block_pattern(
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index ad7d1b7421..f485d8bbbd 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -1121,7 +1121,9 @@ def _only_used_by(lhs: DFPattern | PatternSeq, rhs: DFPattern | PatternSeq, inde
return ffi.only_used_by(lhs, rhs, index) # type: ignore
-def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None):
+def make_fused_bias_activation_pattern(
+ op_name, with_bias=False, activation=None, allow_reshape=False
+):
"""
A simple utility to create patterns for an operation fused with bias
addition and activation.
@@ -1136,6 +1138,9 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
activation: str
The name of an activation Relax op, such as "relax.nn.relu"
+ allow_reshape: bool
+ Whether to allow reshape operation before bias addition (for PyTorch frontend)
+
Returns
-------
pattern: DFPattern
@@ -1147,7 +1152,11 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
if with_bias:
bias = wildcard()
- out = is_op("relax.add")(out, bias)
+ if allow_reshape:
+ reshaped_bias = is_op("relax.reshape")(bias, wildcard(), varg_default_wildcard=True)
+ out = is_op("relax.add")(out, reshaped_bias, varg_default_wildcard=True)
+ else:
+ out = is_op("relax.add")(out, bias)
if activation:
return is_op(activation)(out)
diff --git a/tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py b/tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py
new file mode 100644
index 0000000000..ca9736af9f
--- /dev/null
+++ b/tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import torch
+
+from tvm import relax
+from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern
+from tvm.relax.frontend.torch import from_fx
+
+
+def test_conv2d_bias_relu_fusion():
+ """Test PyTorch conv2d + bias + relu fusion with reshape pattern"""
+
+ class Conv2dBiasRelu(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, x):
+ return self.relu(self.conv(x))
+
+ # Convert PyTorch model to Relax IR
+ model = Conv2dBiasRelu()
+ graph_model = torch.fx.symbolic_trace(model)
+ input_info = [([1, 3, 10, 10], "float32")]
+
+ with torch.no_grad():
+ mod = from_fx(graph_model, input_info)
+
+ # Apply fusion with modified pattern
+ patterns = [
+ (
+ "conv2d_bias_activation_with_reshape",
+ make_fused_bias_activation_pattern(
+ "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
+ ),
+ )
+ ]
+
+ fused_mod = relax.transform.FuseOpsByPattern(patterns, bind_constants=False)(mod)
+
+ # Verify fusion occurred
+ fused_functions = [name for name in fused_mod.functions.keys() if "fused" in str(name)]
+
+ assert len(fused_functions) == 1, "Expected exactly one fused function"
+
+ # Verify the fused function contains all operations
+ fused_func = fused_mod[fused_functions[0]]
+ assert hasattr(fused_func, "attrs"), "Fused function should have attributes"
+ assert "Composite" in fused_func.attrs, "Fused function should have Composite attribute"
+
+
+def test_conv2d_bias_relu_fusion_comparison():
+ """Compare fusion with and without allow_reshape option"""
+
+ class Conv2dBiasRelu(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, x):
+ return self.relu(self.conv(x))
+
+ model = Conv2dBiasRelu()
+ graph_model = torch.fx.symbolic_trace(model)
+ input_info = [([1, 3, 10, 10], "float32")]
+
+ with torch.no_grad():
+ mod = from_fx(graph_model, input_info)
+
+ # Test with allow_reshape=False
+ old_patterns = [
+ (
+ "conv2d_bias_activation_old",
+ make_fused_bias_activation_pattern(
+ "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=False
+ ),
+ )
+ ]
+
+ old_fused_mod = relax.transform.FuseOpsByPattern(old_patterns, bind_constants=False)(mod)
+
+ # Test with allow_reshape=True
+ new_patterns = [
+ (
+ "conv2d_bias_activation_new",
+ make_fused_bias_activation_pattern(
+ "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
+ ),
+ )
+ ]
+
+ new_fused_mod = relax.transform.FuseOpsByPattern(new_patterns, bind_constants=False)(mod)
+
+ # Both should create fused functions
+ old_fused_functions = [name for name in old_fused_mod.functions.keys() if "fused" in str(name)]
+ new_fused_functions = [name for name in new_fused_mod.functions.keys() if "fused" in str(name)]
+
+ assert len(old_fused_functions) >= 1, "Old pattern should create at least one fused function"
+ assert len(new_fused_functions) >= 1, "New pattern should create at least one fused function"
+
+
+def test_conv2d_no_fusion_case():
+ """Test case where fusion should not occur"""
+
+ class Conv2dNoBias(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 6, 3, bias=False)
+
+ def forward(self, x):
+ return self.conv(x)
+
+ model = Conv2dNoBias()
+ graph_model = torch.fx.symbolic_trace(model)
+ input_info = [([1, 3, 10, 10], "float32")]
+
+ with torch.no_grad():
+ mod = from_fx(graph_model, input_info)
+
+ # Apply fusion pattern
+ patterns = [
+ (
+ "conv2d_bias_activation",
+ make_fused_bias_activation_pattern(
+ "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
+ ),
+ )
+ ]
+
+ fused_mod = relax.transform.FuseOpsByPattern(patterns, bind_constants=False)(mod)
+
+ # No fusion should occur
+ fused_functions = [name for name in fused_mod.functions.keys() if "fused" in str(name)]
+
+ assert len(fused_functions) == 0, "No fusion should occur for conv2d without bias and relu"
+
+
+if __name__ == "__main__":
+ test_conv2d_bias_relu_fusion()
+ test_conv2d_bias_relu_fusion_comparison()
+ test_conv2d_no_fusion_case()