gemini-code-assist[bot] commented on code in PR #18523:
URL: https://github.com/apache/tvm/pull/18523#discussion_r2571372727
##########
src/relax/op/op_common.h:
##########
@@ -211,7 +211,15 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx
}
auto output_sinfo =
ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
output_sinfo->dtype = f_compute_out_dtype(input_sinfo);
- return TensorStructInfo(output_sinfo);
+ if (call->sinfo_args.size() > 0) {
+ auto defined_sinfo = call->sinfo_args[0].as<TensorStructInfoNode>();
+ auto shape = output_sinfo->GetShape();
+ ICHECK(shape.defined());
+ return TensorStructInfo(ShapeExpr(shape.value()), output_sinfo->dtype,
+ defined_sinfo->vdevice.value());
+ } else {
+ return TensorStructInfo(output_sinfo);
+ }
Review Comment:

The logic for propagating `vdevice` is unsafe. It dereferences `defined_sinfo`
and calls `defined_sinfo->vdevice.value()` without checking that
`call->sinfo_args[0]` is actually a `TensorStructInfoNode` (otherwise `as<>()`
returns null) or that `vdevice` is defined. Either case can crash at runtime:
the null dereference segfaults, and `.value()` on an undefined `Optional` fails.
Please add checks and fall back to the inferred struct info when they do not hold.
```cpp
if (call->sinfo_args.size() > 0) {
  if (const auto* defined_sinfo =
          call->sinfo_args[0].as<TensorStructInfoNode>()) {
    if (defined_sinfo->vdevice.defined()) {
      auto shape = output_sinfo->GetShape();
      ICHECK(shape.defined());
      return TensorStructInfo(ShapeExpr(shape.value()),
                              output_sinfo->dtype,
                              defined_sinfo->vdevice.value());
    }
  }
}
return TensorStructInfo(output_sinfo);
```
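For illustration only, here is a TVM-free Python sketch of the same guard. `TensorInfo` and `infer_unary` are hypothetical stand-ins, not the Relax API: the vdevice is taken from `sinfo_args[0]` only when that entry is tensor info with a vdevice actually set, and every other case falls back to the info derived from the input.

```python
from dataclasses import dataclass, replace
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class TensorInfo:
    """Hypothetical stand-in for TensorStructInfo (not the TVM class)."""

    shape: Optional[Tuple[int, ...]]
    dtype: str
    vdevice: Optional[str] = None


def infer_unary(
    input_info: TensorInfo, out_dtype: str, sinfo_args: Sequence[object]
) -> TensorInfo:
    """Hypothetical helper mirroring the suggested C++ guard."""
    # Start from the input's info with the computed output dtype.
    output = replace(input_info, dtype=out_dtype)
    if sinfo_args:
        defined = sinfo_args[0]
        # Analogue of `.as<TensorStructInfoNode>()` returning null,
        # followed by the `vdevice.defined()` check.
        if isinstance(defined, TensorInfo) and defined.vdevice is not None:
            # Analogue of ICHECK(shape.defined()).
            if output.shape is None:
                raise ValueError("output shape must be known to attach a vdevice")
            return replace(output, vdevice=defined.vdevice)
    # Fallback: keep the struct info inferred from the input.
    return output
```

With this shape, a missing, non-tensor, or vdevice-less `sinfo_args[0]` degrades to the plain inferred result instead of crashing.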
##########
include/tvm/runtime/tensor.h:
##########
@@ -188,14 +189,25 @@ class Tensor : public tvm::ffi::Tensor {
*/
TVM_DLL static void CopyFromBytes(const DLTensor* to, void* from, size_t nbytes,
TVMStreamHandle stream = nullptr);
+
+ TVM_DLL void SetScope(ffi::String scope);
+ TVM_DLL ffi::String GetScope() const;
+
+ protected:
+ /*!
+ * \brief The memory scope
+ * represents the underlaying scope information of device
Review Comment:

There is a typo in the comment. "underlaying" should be "underlying".
```c
* represents the underlying scope information of device
```
##########
src/relax/transform/static_plan_block_memory.cc:
##########
@@ -990,6 +1032,30 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef().def("relax.transform.StaticPlanBlockMemory",
StaticPlanBlockMemory);
}
+PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array<PrimExpr> pshape, DataType dtype,
+ VDevice vdevice) {
+ int image_row_align = vdevice->target->GetAttr<Integer>("image_base_address_alignment")
+ .value_or(Integer(64))
+ ->value;
+
+ // TODO(Siva) Assuming no any dimensions for now.
+ struct Shape {
+ const ffi::Array<PrimExpr>& shape;
+ int64_t operator[](size_t i) const { return *tir::as_const_int(shape[i]); }
Review Comment:

The implementation of `Shape::operator[]` assumes that all shape dimensions
are constant integers: `tir::as_const_int` returns `nullptr` for a
non-constant expression, so dereferencing it crashes if any dimension is a
symbolic variable. While the TODO comment acknowledges this, it is a critical
issue that can lead to runtime failures. Either add a check that the
dimensions are constant, or update the logic to handle symbolic shapes
gracefully.
```cpp
int64_t operator[](size_t i) const {
  const auto* int_imm = shape[i].as<IntImmNode>();
  ICHECK(int_imm) << "Shape dimensions must be constant for texture memory size calculation.";
  return int_imm->value;
}
```
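The same "check before you dereference" idea, sketched in plain Python for illustration. `SymVar` and `const_dim` are hypothetical names standing in for a symbolic `tir.Var` and the `operator[]` above; the sketch does not reproduce TVM's texture size formula, only the guard that turns a silent crash into a clear error.

```python
from typing import Sequence, Union


class SymVar:
    """Hypothetical stand-in for a symbolic dimension such as a tir.Var."""

    def __init__(self, name: str) -> None:
        self.name = name


Dim = Union[int, SymVar]


def const_dim(shape: Sequence[Dim], i: int) -> int:
    """Return shape[i] if it is constant, otherwise fail with a clear message.

    Analogue of the suggested `as<IntImmNode>()` + ICHECK, instead of
    dereferencing whatever `as_const_int` returns.
    """
    dim = shape[i]
    if not isinstance(dim, int):
        raise ValueError(
            f"dimension {i} ({dim.name}) is symbolic; "
            "texture memory size calculation needs constant shapes"
        )
    return dim
```

For example, `const_dim([1, 64, 56, 4], 1)` returns `64`, while `const_dim([SymVar("n"), 64], 0)` raises `ValueError` rather than reading through a null pointer.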
##########
python/tvm/tir/transform/transform.py:
##########
@@ -893,6 +893,17 @@ def LowerOpaqueBlock():
return _ffi_api.LowerOpaqueBlock() # type: ignore
+def InjectTextureAlloc():
+ """Inject Texture Allocation Intrensic to make sure appropriate lowering
Review Comment:

There is a typo in the docstring. "Intrensic" should be "Intrinsic".
```suggestion
"""Inject Texture Allocation Intrinsic to make sure appropriate lowering
```
##########
include/tvm/tir/transform.h:
##########
@@ -773,6 +773,12 @@ TVM_DLL Pass DefaultGPUSchedule();
*/
TVM_DLL Pass UseAssumeToReduceBranches();
+/*!
+ * \brief Inject Texture Allocation intrensic.
Review Comment:

There is a typo in the comment. "intrensic" should be "intrinsic".
```c
* \brief Inject Texture Allocation intrinsic.
```
##########
python/tvm/dlight/adreno/layout_transform.py:
##########
@@ -0,0 +1,133 @@
+# Licensed to the apache software foundation (asf) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, unused-variable
+
+"Schedules for Texture Based Layout Transforms"
+from typing import List, Union
+
+from tvm import tir
+from tvm.target import Target
+from .. import analysis
+
+from .base import AdrenoScheduleRule
+
+
+class LayoutTransform(AdrenoScheduleRule):
+ """Texture based Layout Transform Dlight Schedule for Adreno"""
+
+ def __init__(self, use_op_name=True):
+ self.use_op_name = use_op_name
+
+ # TODO: Try using Coalesced Writes...
+ def apply( # pylint: disable=too-many-locals
+ self,
+ func: Union[tir.PrimFunc, tir.Schedule],
+ target: Target,
+ _: bool,
+ ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
+ # pylint: disable=invalid-name
+ if not (isinstance(func, (tir.PrimFunc, tir.Schedule))) or not self.is_target_available(
+ target
+ ):
+ return None
+
+ if isinstance(func, tir.PrimFunc):
+ sch = tir.Schedule(func)
+ sch.work_on("main")
+ elif isinstance(func, tir.Schedule):
+ sch = func
+
+ root_block = analysis.get_root_block(sch, sch.func_working_on)
+
+ if len(sch.get_child_blocks(root_block)) != 1:
+ return None
+
+ blk = sch.get_child_blocks(root_block)[0]
+ block_info = analysis.get_block_info(sch, blk)
+ if not (
+ (self.use_op_name and block_info.name == "te_layout_transform")
+ or (not self.use_op_name and block_info.is_layout_transform(sch))
+ ):
+ return None
+
+ read_buf, write_buf = (block_info.read_bufs[0], block_info.write_bufs[0])
+ lps = block_info.get_loops()
+ lpv_read, lpv_write = (
+ read_buf.assoc_lps[-1],
+ write_buf.assoc_lps[-1],
+ )
+
+ if lpv_read is None or lpv_write is None:
+ return None
+
+ vlen_read, vlen_write = read_buf.get_vecsize(), write_buf.get_vecsize()
+ local_cache = sch.get(lpv_read) != sch.get(lpv_write) or vlen_read != vlen_write
+ block_loops = [
+ lp
+ for lp in lps
+ if sch.get(lp) != sch.get(lpv_read) and sch.get(lp) != sch.get(lpv_write)
+ ]
+ vec_loops = (
+ [lpv_read, lpv_write] if sch.get(lpv_read) != sch.get(lpv_write) else (lpv_read,)
+ )
+ sch.reorder(*block_loops, *vec_loops)
+ # TODO: Additional Pragmas and stuff
+ if local_cache:
+ if sch.get(lpv_read) != sch.get(lpv_write):
+ blp_read, vlp_read = sch.split(
+ lpv_read, [None, vlen_read], preserve_unit_iters=True
+ )
+ blp_write, vlp_write = sch.split(
+ lpv_write, [None, vlen_write], preserve_unit_iters=True
+ )
+ sch.reorder(blp_read, blp_write, vlp_read, vlp_write)
+ block_loops += [blp_read, blp_write]
+ rblk = sch.cache_read(blk, 0, "local")
+ sch.compute_at(rblk, block_loops[-1], preserve_unit_loops=True)
+ sch.vectorize(sch.get_loops(rblk)[-1])
+ sch.vectorize(vlp_write)
+ else:
+ if vlen_read > vlen_write:
+ read_lp, vec_lp = sch.split(blk, [None, vlen_write], preserve_unit_iters=True)
+ rblk = sch.cache_read(blk, 0, "local")
+ sch.compute_at(rblk, read_lp, preserve_unit_loops=True)
+ sch.vectorize(sch.get_loops(rblk)[-1])
+ sch.vectorize(vec_lp)
+ else:
+ rblk = sch.cache_read(blk, 0, "local")
+ sch.compute_at(rblk, block_loops[-1], preserve_unit_loops=True)
+ _, vread_lp = sch.split(
+ sch.get_loops(rblk)[-1], vlen_read, preserve_unit_iters=True
+ )
+ sch.vectorize(vread_lp)
+ sch.vectorize(vlp_write)
+ else:
+ blp, vlp = sch.split(lpv_read, [None, vlen_read], preserve_unit_iters=True)
+ block_loops += [blp]
+ sch.vectorize(vlp)
+
+ b = sch.fuse(*block_loops)
+ tx_extent = min(sch.get(b).extent, 256)
+ candidates = [1, 2, 4, 8, 16, 32]
+ ux = sch.sample_categorical(
+ candidates, [1 / len(candidates) for _ in range(len(candidates))]
+ )
Review Comment:

The variable `ux` is assigned a value from `sch.sample_categorical` but is
never used. This unused variable should be removed.
##########
python/tvm/dlight/adreno/__init__.py:
##########
@@ -18,3 +18,8 @@
Adreno schedule rules.
"""
from .convolution import Conv2d
+from .layout_transform import LayoutTransform
+from .fallback import Fallback
+from .pool import Pool2D
+
+# from .fallback import Fallback
Review Comment:

This commented-out import appears to be dead code and should be removed to
keep the codebase clean.
##########
python/tvm/dlight/adreno/pool.py:
##########
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+""" Pool schedule rule for Adreno operators."""
+
+from tvm import tir
+from tvm.target import Target
+
+from .base import AdrenoScheduleRule
+from .. import analysis
+
+
+# pylint: disable=invalid-name, unused-variable
+class Pool2D(AdrenoScheduleRule):
+ def apply( # pylint: disable=too-many-locals,missing-docstring
+ self,
+ func: tir.PrimFunc,
+ target: Target,
+ _: bool,
+ ) -> tir.Schedule:
+ sch = tir.Schedule(func)
+ root = sch.get_block(name="root", func_name="main")
+
+ blocks = sch.get_child_blocks(root)
+ blocks_names = [sch.get(blk).name_hint for blk in blocks]
+
+ if not "adaptive_pool_sum" in blocks_names and not "pool_max" in blocks_names:
+ return None
+
+ def schedule_pad(blk: tir.schedule.BlockRV):
+ lps, veclp = sch.get_loops(blk)[:-1], sch.get_loops(blk)[-1]
+ sch.vectorize(veclp)
+ b = sch.fuse(*lps)
+ tx_extent = min(int(sch.get(b).extent) & ~int(sch.get(b).extent - 1), 256)
+ bx, tx = sch.split(b, [None, tx_extent])
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(tx, "threadIdx.x")
+
+ def schedule_max_pool(blk: tir.schedule.BlockRV):
+ block_info = analysis.get_block_info(sch, blk)
+ iters_kind = "".join([_iter.kind for _iter in block_info.iters])
+ if iters_kind != "SSSSSRR":
+ return None
+
+ lps = sch.get_loops(blk)
+ block_lps, vec_lp, red_lps = lps[:4], lps[4], lps[5:]
+ write_blk = sch.cache_write(blk, 0, "local")
+ sch.reverse_compute_at(write_blk, vec_lp)
+ b = sch.fuse(*block_lps)
+ tx_extent = min(int(sch.get(b).extent) & ~int(sch.get(b).extent - 1), 256)
+ bx, tx = sch.split(b, [None, tx_extent])
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec_lp)
+
+ return True
+
+ passed_reduction = False
+ for blk in blocks:
+ if sch.get(blk).name_hint == "pad_temp":
+ schedule_pad(blk)
+ elif (
+ sch.get(blk).name_hint == "adaptive_pool_sum"
+ or sch.get(blk).name_hint == "pool_max"
+ ):
+ ok = schedule_max_pool(blk)
+ if not ok:
+ return None
+ passed_reduction = True
+ else:
+ try:
+ if passed_reduction:
+ sch.reverse_compute_inline(blk)
+ else:
+ sch.compute_inline(blk)
+ except: # pylint: disable=bare-except
Review Comment:

Using a bare `except:` also catches `KeyboardInterrupt` and `SystemExit`,
which can mask unexpected errors and make debugging difficult. It's better to
catch a narrower exception, such as `Exception` or `tvm.error.TVMError`.
```suggestion
except Exception: # pylint: disable=broad-exception-caught
```
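A small, TVM-independent illustration of the difference, using hypothetical helper names: a bare `except:` swallows `KeyboardInterrupt` and `SystemExit`, whereas `except Exception:` still handles ordinary errors but lets those two propagate.

```python
def swallow_bare(exc: BaseException) -> bool:
    """Return True if a bare `except:` handles exc (it always does)."""
    try:
        raise exc
    except:  # pylint: disable=bare-except  (the pattern under review)
        return True


def caught_by_exception(exc: BaseException) -> bool:
    """Return True if `except Exception:` handles exc, False if it escapes."""
    try:
        try:
            raise exc
        except Exception:
            return True
    except BaseException:  # only reached when the inner clause let exc through
        return False
```

For example, `swallow_bare(KeyboardInterrupt())` is `True`, while `caught_by_exception(KeyboardInterrupt())` is `False` and `caught_by_exception(RuntimeError())` is `True`.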
##########
python/tvm/dlight/adreno/fallback.py:
##########
@@ -0,0 +1,177 @@
+# licensed to the apache software foundation (asf) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Dlight Adreno Fallback Schedules"""
+
+
+from typing import List, Union
+
+from tvm import tir
+from tvm.target import Target
+from .. import analysis
+from .base import AdrenoScheduleRule
+from .utils import get_texture_storage
+
+
+# pylint: disable=invalid-name,missing-function-docstring,unused-variable,unused-import
+class Fallback(AdrenoScheduleRule):
+ """Texture Based Fallback Schedule(s) for Adreno"""
+
+ @staticmethod
+ def schedule_inline_blocks(
+ sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]
+ ) -> List[tir.schedule.BlockRV]:
+ """
+ Auto Inlines Injective and Element-wise Operations while trying to omit data pad blocks...
+ """
+
+ if blocks is None:
+ root_blk = analysis.get_root_block(sch)
+ blocks = sch.get_child_blocks(root_blk)
+
+ remaining_blocks = []
+ for blk in blocks:
+ block_info = analysis.get_block_info(sch, blk)
+ if block_info.is_injective() and not block_info.is_data_pad(sch):
+ if len(block_info.consumers) == 1:
+ try:
+ sch.compute_inline(blk)
+ except Exception: # pylint: disable=broad-exception-caught
+ remaining_blocks.append(blk)
+ elif len(block_info.producers) == 1:
+ inlined_once = False
+ try:
+ # Would cause an issue inlining to producer with multiple consumers
+ while (
+ len(sch.get_producers(blk)) == 1
+ and len(sch.get_consumers(sch.get_producers(blk)[0])) == 1
+ ):
+ sch.reverse_compute_inline(blk)
+ inlined_once = True
+ except Exception: # pylint: disable=broad-exception-caught
+ break
+ if not inlined_once:
+ remaining_blocks.append(blk)
+ else:
+ remaining_blocks.append(blk)
+ else:
+ remaining_blocks.append(blk)
+ return remaining_blocks
+
+ @staticmethod
+ def schedule_annotate_storage(sch: tir.Schedule, func=get_texture_storage):
+ """Annotates intermediate buffers to textures whenever it's possible to do so"""
+ return
+ # pylint: disable=unreachable
+ root_blk = analysis.get_root_block(sch)
+ blocks = sch.get_child_blocks(root_blk)
+
+ for blk in blocks:
+ block_info = analysis.get_block_info(sch, blk)
+ scope = func(block_info)
+ if scope is not None and len(sch.get_consumers(blk)) > 0:
+ sch.set_scope(blk, 0, scope)
Review Comment:

The body of `schedule_annotate_storage` is unreachable because of the early
`return` statement, so the function is currently a no-op. If it is not yet
implemented or used, it should be removed to avoid dead code. If it is
intended to be used, the `return` statement should be removed.
##########
python/tvm/dlight/adreno/convolution.py:
##########
@@ -16,215 +16,92 @@
# under the License.
# pylint: disable=missing-docstring, invalid-name
"""A Conv2d schedule rule for Adreno GPU operators."""
-from dataclasses import dataclass
-from typing import List, Optional
+from typing import Optional, Union
from tvm import tir
from tvm.target import Target
-from tvm.tir import IterVar
-from tvm.tir.schedule.schedule import BlockRV
-from ..analysis import BlockInfo, IterInfo
+from .utils import schedule_inline_blocks, schedule_storage_annotate, schedule_default
+from .. import analysis
from .base import AdrenoScheduleRule
-def is_spatial_block(sch: tir.Schedule, block: BlockRV) -> bool:
- block_stmt = sch.get(block)
- iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
- return iter_types == {IterVar.DataPar}
-
-
-def is_reduction_block(sch: tir.Schedule, block: BlockRV) -> bool:
- block_stmt = sch.get(block)
- iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
- return iter_types == {IterVar.CommReduce, IterVar.DataPar}
-
-
-def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
- result = []
- for producer in sch.get_producers(block):
- result.append(producer)
- result.extend(_collect_producers(sch, producer))
- return result
-
-
-def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
- result = []
- for consumer in sch.get_consumers(block):
- result.append(consumer)
- result.extend(_collect_consumers(sch, consumer))
- return result
-
-
-def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo:
- def _iter_kind(loop: tir.IterVar) -> str:
- return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O")
-
- def _is_reduction_block(block: tir.schedule.BlockRV):
- for iter_var in sch.get(block).iter_vars:
- if _iter_kind(iter_var) == "R":
- return True
- return False
-
- return BlockInfo(
- name=sch.get(block).name_hint,
- iters=[
- IterInfo(
- kind=_iter_kind(iter_var),
- var=iter_var.var,
- dom=iter_var.dom.extent,
- loop_rv=loop_rv,
- )
- for loop_rv, iter_var in zip(sch.get_loops(block), sch.get(block).iter_vars)
- ],
- block_rv=block,
- reduction_block=_is_reduction_block(block),
- )
-
-
-def get_reduction_blocks(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]) -> bool:
- # NOTE: We assume there is only one reduction block in the function
- # all blocks are required to be spatial or reduction
- if not all(
- [is_reduction_block(sch, block) or is_spatial_block(sch, block) for block in blocks]
- ):
- return None
-
- # There is only one reduction block
- reduction_blocks = [block for block in blocks if is_reduction_block(sch, block)]
- if len(reduction_blocks) != 1:
- return None
-
- return reduction_blocks[0]
-
-
-def is_convolution(sch: tir.Schedule, block: tir.schedule.BlockRV):
- # TODO: Use buffer access patterns to discover convolution type kernels instead of using name.
- return (
- sch.get(block).name_hint.count("conv2d_NCHWc_OIHWo")
- and "".join([iter_type.kind for iter_type in get_block_info(sch, block).iters])
- == "SSSSSRRR"
- )
-
-
class Conv2d(AdrenoScheduleRule):
"""The schedule rule for convolution computation"""
- @dataclass
- class Config:
- block_size_x: int = 8
- block_size_y: int = 8
- vector_size: int = 1
- unroll: int = 256 # 0 means no unroll
- use_shared: bool = True
- storage_align: bool = False
- inner_x: bool = False
-
- def get_configs(self, target: Target) -> Config:
- """Get the schedule config for the target"""
- if target.kind.name == "cuda" or target.kind.name == "rocm":
- return Conv2d.Config(
- block_size_x=8,
- block_size_y=16,
- vector_size=2,
- unroll=256,
- use_shared=True,
- storage_align=True,
- inner_x=False,
- )
- elif target.kind.name == "opencl" and (
- ("android" in str(target.host)) or ("adreno" in str(target.attrs))
- ):
- return Conv2d.Config(
- block_size_x=32,
- block_size_y=4,
- vector_size=8,
- unroll=16,
- use_shared=False,
- storage_align=False,
- inner_x=True,
- )
- else:
- return Conv2d.Config()
+ @staticmethod
+ def schedule_conv2d(sch: tir.Schedule, blk: tir.schedule.BlockRV):
+ # TODO: Loop Pattern mayn't be reliable, need to perform better analysis.
+ n, oc, oh, ow, ob, ic, kh, kw = sch.get_loops(blk)
+
+ # bz, vz, tz = sch.split(oc, sch.sample_perfect_tile(oc, 3, 32))
+ # by, vy, ty = sch.split(oh, sch.sample_perfect_tile(oh, 3, 32))
+ # bx, vx, tx = sch.split(ow, sch.sample_perfect_tile(ow, 3, 32))
+
+ bz, vz, tz = sch.split(oc, [None, 8, 1], preserve_unit_iters=True)
+ by, vy, ty = sch.split(oh, [None, 1, 16], preserve_unit_iters=True)
+ bx, vx, tx = sch.split(ow, [None, 1, 16], preserve_unit_iters=True)
+
+ bz = sch.fuse(n, bz, preserve_unit_iters=True)
+ sch.reorder(bz, by, bx, vz, vy, vx, tz, ty, tx, ob)
+ sch.bind(bz, "blockIdx.z")
+ sch.bind(by, "blockIdx.y")
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(vz, "vthread.z")
+ sch.bind(vy, "vthread.y")
+ sch.bind(vx, "vthread.x")
+ sch.bind(tz, "threadIdx.z")
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(tx, "threadIdx.x")
+
+ rblk = sch.cache_read(blk, 0, "local")
+ ico, icb = sch.split(ic, [None, 4], preserve_unit_iters=True)
+ sch.reorder(ico, kh, kw, icb, ob)
+
+ sch.compute_at(rblk, kw, preserve_unit_loops=True)
+ sch.vectorize(sch.get_loops(rblk)[-1])
+ wblk = sch.cache_write(blk, 0, "local")
+ sch.reverse_compute_at(wblk, tx, preserve_unit_loops=True)
+ sch.vectorize(sch.get_loops(wblk)[-1])
+ init_blk = sch.decompose_reduction(blk, tx)
+ sch.vectorize(sch.get_loops(init_blk)[-1])
def apply( # pylint: disable=too-many-locals,missing-docstring
self,
- func: tir.PrimFunc,
+ func: Union[tir.PrimFunc],
Review Comment:

The type hint for `func` is `Union[tir.PrimFunc]`, but the implementation on
line 75 checks for `isinstance(func, (tir.PrimFunc, tir.Schedule))`. The type
hint should be updated to `Union[tir.PrimFunc, tir.Schedule]` to accurately
reflect the function's capabilities.
```suggestion
func: Union[tir.PrimFunc, tir.Schedule],
```
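Why the one-member hint is misleading: `typing` collapses `Union[X]` to `X`, so the current annotation only admits `PrimFunc`. A minimal sketch with hypothetical stand-in classes (not the real `tir` types) shows the collapse and what the suggested hint admits; `accepted_types` is a hypothetical helper.

```python
from typing import Union, get_args


class PrimFunc:  # hypothetical stand-in for tir.PrimFunc
    pass


class Schedule:  # hypothetical stand-in for tir.Schedule
    pass


# A one-member Union collapses to that member, so this only admits PrimFunc.
NarrowHint = Union[PrimFunc]
# The suggested hint names both types the isinstance() check accepts.
WideHint = Union[PrimFunc, Schedule]


def accepted_types(hint) -> tuple:
    """Return the classes a hint admits (the hint itself if it is not a Union)."""
    return get_args(hint) or (hint,)
```

Here `NarrowHint is PrimFunc` holds, and `accepted_types(WideHint)` returns `(PrimFunc, Schedule)`, matching the tuple passed to `isinstance`.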
##########
tests/python/relax/adreno/test_transform_annotate_custom_scope.py:
##########
@@ -106,6 +106,7 @@ def verify(mod, expected):
mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod)
mod = tvm.relax.transform.Normalize()(mod)
+ print(mod)
Review Comment:

This `print(mod)` statement appears to be for debugging purposes and should
be removed from the final test code.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.