gemini-code-assist[bot] commented on code in PR #18511:
URL: https://github.com/apache/tvm/pull/18511#discussion_r2564456325
##########
python/tvm/ir/instrument.py:
##########
@@ -285,6 +291,48 @@ def run_after_pass(self, mod, info):
class PrintBeforeAll:
"""Print the name of the pass, the IR, only before passes execute."""
- def run_before_pass(self, mod, info):
+ def run_before_pass(self, mod: IRModule, info: PassInfo):
print(f"Before Running Pass: {info}")
print(mod)
+
+
+@pass_instrument
+class DumpIR:
+ """Dump the IR after the pass runs."""
+
+ def __init__(self, dump_dir: Path | str, refresh: bool = False):
+ if isinstance(dump_dir, Path):
+ self.dump_dir = dump_dir
+ else:
+ self.dump_dir = Path(dump_dir)
+ self.counter = 0
+ if refresh and self.dump_dir.exists():
Review Comment:

The check `self.dump_dir.exists()` is not sufficient here. If `dump_dir`
points to an existing regular file, `exists()` returns `True` and
`_safe_remove_dump_dir` is called. That method then calls
`self.dump_dir.iterdir()`, which raises `NotADirectoryError` and crashes.
Using `self.dump_dir.is_dir()` is safer because it ensures the path is a
directory before attempting to iterate over its contents.
```suggestion
if refresh and self.dump_dir.is_dir():
```
##########
python/tvm/ir/instrument.py:
##########
@@ -285,6 +291,48 @@ def run_after_pass(self, mod, info):
class PrintBeforeAll:
"""Print the name of the pass, the IR, only before passes execute."""
- def run_before_pass(self, mod, info):
+ def run_before_pass(self, mod: IRModule, info: PassInfo):
print(f"Before Running Pass: {info}")
print(mod)
+
+
+@pass_instrument
+class DumpIR:
+ """Dump the IR after the pass runs."""
+
+ def __init__(self, dump_dir: Path | str, refresh: bool = False):
+ if isinstance(dump_dir, Path):
+ self.dump_dir = dump_dir
+ else:
+ self.dump_dir = Path(dump_dir)
+ self.counter = 0
+ if refresh and self.dump_dir.exists():
+ self._safe_remove_dump_dir()
+
+ def _safe_remove_dump_dir(self):
+ """Remove dump directory only if it contains only dumped IR files."""
+ # Pattern for dumped files: {counter:03d}_{pass_name}.py
+ dump_pattern = re.compile(r"^\d{3}_.*\.py$")
+
+ # Check all files in the directory
+ for item in self.dump_dir.iterdir():
+ # If there's a subdirectory or a file that doesn't match the
pattern, abort
+ if item.is_dir() or not dump_pattern.match(item.name):
+ print(
+ f"WARNING: Skipping removal of {self.dump_dir} as it
contains "
+ f"non-dumped files or directories. Please clean it
manually."
+ )
+ return
+
+ # Safe to remove - only contains dumped files
+ shutil.rmtree(self.dump_dir, ignore_errors=True)
+
+ def run_after_pass(self, mod: IRModule, info: PassInfo):
+ self.dump_dir.mkdir(parents=True, exist_ok=True)
+ try:
+ with open(self.dump_dir / f"{self.counter:03d}_{info.name}.py",
"w") as f:
+ f.write(mod.script())
Review Comment:

The pass name `info.name` is used directly to construct a filename. Pass
names may contain characters that are invalid in filenames on certain operating
systems (e.g., `/`, `:`, `*`). This could cause `open()` to fail with an
`OSError`. To prevent this, you should sanitize the pass name to replace or
remove any invalid characters.
```suggestion
sanitized_pass_name = re.sub(r'[<>:"/\\|?*]', '_', info.name)
with open(self.dump_dir / f"{self.counter:03d}_{sanitized_pass_name}.py", "w") as f:
    f.write(mod.script())
```
##########
python/tvm/ir/instrument.py:
##########
@@ -285,6 +291,48 @@ def run_after_pass(self, mod, info):
class PrintBeforeAll:
"""Print the name of the pass, the IR, only before passes execute."""
- def run_before_pass(self, mod, info):
+ def run_before_pass(self, mod: IRModule, info: PassInfo):
print(f"Before Running Pass: {info}")
print(mod)
+
+
+@pass_instrument
+class DumpIR:
+ """Dump the IR after the pass runs."""
+
+ def __init__(self, dump_dir: Path | str, refresh: bool = False):
+ if isinstance(dump_dir, Path):
+ self.dump_dir = dump_dir
+ else:
+ self.dump_dir = Path(dump_dir)
+ self.counter = 0
+ if refresh and self.dump_dir.exists():
+ self._safe_remove_dump_dir()
+
+ def _safe_remove_dump_dir(self):
+ """Remove dump directory only if it contains only dumped IR files."""
+ # Pattern for dumped files: {counter:03d}_{pass_name}.py
+ dump_pattern = re.compile(r"^\d{3}_.*\.py$")
+
+ # Check all files in the directory
+ for item in self.dump_dir.iterdir():
+ # If there's a subdirectory or a file that doesn't match the
pattern, abort
+ if item.is_dir() or not dump_pattern.match(item.name):
+ print(
+ f"WARNING: Skipping removal of {self.dump_dir} as it
contains "
+ f"non-dumped files or directories. Please clean it
manually."
+ )
+ return
+
+ # Safe to remove - only contains dumped files
+ shutil.rmtree(self.dump_dir, ignore_errors=True)
Review Comment:

Using `ignore_errors=True` in `shutil.rmtree` can suppress important errors,
such as permission issues, making it harder to diagnose problems. It's better
to handle potential exceptions explicitly and provide feedback to the user by
logging a warning.
```suggestion
try:
    shutil.rmtree(self.dump_dir)
except OSError as e:
    print(f"WARNING: Failed to remove directory {self.dump_dir}: {e}")
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]