I'm sorry that I didn't reply to you in time; I have had too many courses recently. Here's the code:

import tvm
from tvm import relay

import numpy as np

from tvm.contrib.download import download_testdata

# PyTorch imports
import torch
import torchvision

######################################################################
# Load a pretrained PyTorch model
# -------------------------------
model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# We grab the TorchScripted model via tracing
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

######################################################################
# Load a test image
# -----------------
# Classic cat example!
from PIL import Image
import PIL

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
# A local copy of the image is used here instead of downloading it:
# download_testdata(img_url, "cat.png", module="data")
img_path = "C:\\Users\\Batman\\Desktop\\Ⅲ\\CV\\cat.png"
img = Image.open(img_path).resize((224, 224))

# Preprocess the image and convert to tensor
from torchvision import transforms

my_preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Keep only the first three channels so that an extra (e.g. alpha) channel
# is dropped before normalization with the 3-element mean/std above.
img = np.array(img)
print(img[:, :, 0:3].shape)
img = Image.fromarray(img[:, :, 0:3])
img = my_preprocess(img)
# Add the batch axis; `img` is used as a NumPy array from here on
# (see `img.astype` and `torch.from_numpy` below).
img = np.expand_dims(img, 0)

######################################################################
# Import the graph to Relay
# -------------------------
# Convert PyTorch graph to Relay graph. The input name can be arbitrary.
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

######################################################################
# Relay Build
# -----------
# Compile the graph to llvm target with given input specification.
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the compiled model on target.
from tvm.contrib import graph_executor

dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# Set inputs
m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
# Execute
m.run()
# Get outputs
tvm_output = m.get_output(0)
print(tvm_output.shape)

#####################################################################
# Look up synset name
# -------------------
# Look up prediction top 1 index in 1000 class synset.
synset_url = "".join(
    [
        "https://raw.githubusercontent.com/Cadene/",
        "pretrained-models.pytorch/master/data/",
        "imagenet_synsets.txt",
    ]
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
# Explicit encoding so that reading does not depend on the platform locale.
with open(synset_path, encoding="utf-8") as f:
    synsets = [line.strip() for line in f]

# Each synset line is "<key> <class name ...>": split on the first space only.
key_to_classname = {}
for line in synsets:
    key, _, classname = line.partition(" ")
    key_to_classname[key] = classname

class_url = "".join(
    [
        "https://raw.githubusercontent.com/Cadene/",
        "pretrained-models.pytorch/master/data/",
        "imagenet_classes.txt",
    ]
)
class_name = "imagenet_classes.txt"
class_path = download_testdata(class_url, class_name, module="data")
with open(class_path, encoding="utf-8") as f:
    class_id_to_key = [line.strip() for line in f]

# Get top-1 result for TVM
# Convert the TVM NDArray explicitly with .numpy(); wrapping it in np.array()
# is not a documented conversion and may not yield the numeric data.
top1_tvm = np.argmax(tvm_output.numpy())
tvm_class_key = class_id_to_key[top1_tvm]

# Convert input to PyTorch variable and get PyTorch result for comparison
with torch.no_grad():
    torch_img = torch.from_numpy(img)
    output = model(torch_img)

    # Get top-1 result for PyTorch
    top1_torch = np.argmax(output.numpy())
    torch_class_key = class_id_to_key[top1_torch]

print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))

---

[Visit Topic](https://discuss.tvm.apache.org/t/check-failed-reporter-asserteq-data-shape-data-shape-size-1-weight-shape-1-is-false-denserel-input-dimension-doesnt-match-data-shape-1-512-weight-shape-512-1000/11274/10) to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/34cb9de455d982359b916f8144f0d91faac8ecb41a24c977cf6c579111a2eac1).