Hi everyone,

I am trying to obtain the execution time of each layer in ResNet-18 after 
auto-tuning. When I run the whole architecture, I get results very similar to 
the ones reported in the GPU tutorial (~1.10 ms).

However, when I tune a single layer on its own and apply the best schedule, I 
observe poor performance. For instance, for the first layer of ResNet-18 I 
obtain 0.25 ms, which is the same time I observe with the fallback 
configuration for that layer.

When I check the log file, there is a configuration that performs much better:

`No: 73 GFLOPS: 2176.36/2176.36 result: MeasureResult(costs=(0.00010845095488215488,),`

But it does not seem to be applied when I build and run the module.
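
To check this, I also query the dispatch context directly for the extracted task and look at whether the returned config is a fallback. This is a minimal sketch, assuming the same `log_file` and `task[0]` as in the script below, and assuming (as I understand it) that autotvm configs expose an `is_fallback` flag:

    # Rough check: does apply_history_best find a matching record for this task?
    dispatch_context = autotvm.apply_history_best(log_file)
    best_config = dispatch_context.query(task[0].target, task[0].workload)
    print("is_fallback:", best_config.is_fallback)  # True would mean no record matched
    print(best_config)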

The code I execute is the following:

    import os
    import sys
    import numpy as np

    import tvm
    import topi
    import logging
    from tvm import autotvm
    from tvm import relay
    from tvm.relay import testing
    from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
    from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
    import tvm.contrib.graph_runtime as runtime

    # Details about the target (CPU/GPU)
    target = 'cuda'
    target_host = 'llvm'
    batch_size = 1
    dtype = 'float32'
    # Set the number of threads used for tuning based on the number of
    # physical CPU cores on your machine.
    num_threads = 8
    os.environ["TVM_NUM_THREADS"] = str(num_threads)

    # Set the input name of the graph
    input_name = "data"

    # Arguments to create task

    log_file = "conv2d_224_cuda.log" 
    graph_opt_sch_file = "conv2d_224_cuda_opt.log"
    data_shape = (batch_size, 3, 224, 224)
    kernel_shape = (64, 3, 7, 7)
    out_shape = (batch_size, 64, 56, 56)

    data_shape_type = data_shape + ('float32',)
    kernel_shape_type = kernel_shape + ('float32',)
    kernel_size = (kernel_shape[2], kernel_shape[3])

    strides = (2,2)
    padding = (3,3,3,3)
    dilation = (1,1)

    # Convolution parameters
    args = (('TENSOR', data_shape, 'float32'), ('TENSOR', kernel_shape, 'float32'),
            strides, padding, dilation, 'NCHW', 'float32')

    # Workload for the task
    workload = ('conv2d', data_shape_type, kernel_shape_type, strides, padding,
                dilation, 'NCHW', 'float32')

    data = relay.var("data", shape=data_shape, dtype=dtype)
    kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype)

    # Create a module with given target and extract task from it for auto-tuning
    ctx = tvm.gpu()
    out = relay.nn.conv2d(data, kernel, strides=strides, padding=padding,
                          dilation=dilation, channels=kernel_shape[0],
                          kernel_size=kernel_size, data_layout='NCHW', out_dtype=dtype)
    mod = relay.Module.from_expr(out)

    kernel_weights = tvm.nd.array(np.ones(kernel_shape, dtype=dtype), ctx)
    dict_params = {'kernel': kernel_weights}

    # task is a list with several entries; auto-tuning has to pick one entry (e.g. task[0])
    task = autotvm.task.extract_from_program(mod, target=target, target_host=target_host,
                                             params=dict_params, ops=(relay.op.nn.conv2d,))
    # task[0] = autotvm.task.create(task[0].name, task[0].args, task[0].target,
    #                               task[0].target_host, 'direct')
    # Define type of auto-tuner
    tuner_obj = XGBTuner(task[0])
    print(task[0])

    # logging config (for printing tuning log to the screen)
    logging.getLogger('autotvm').setLevel(logging.DEBUG)
    logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

    # We measure 20 times and take average to reduce variance.
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=20, repeat=3, min_repeat_ms=100, timeout=4))

    #n_trial = len(task.config_space)
    #print(n_trial)
    n_trial = 100
    """tuner_obj.tune(n_trial=n_trial,
               early_stopping = None,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file(log_file)])"""

    # inspect the best config
    dispatch_context = autotvm.apply_history_best(log_file)
    best_config = dispatch_context.query(task[0].target, task[0].workload)
    print("\nBest config:")
    print(best_config)

    # Save optimal config to log file
    text_file = open(graph_opt_sch_file, "w")
    text_file.write(str(best_config))
    text_file.close()

    # create a module to apply best schedule
    out = relay.nn.conv2d(data, kernel, strides=strides, padding=padding,
                          dilation=dilation, channels=kernel_shape[0],
                          kernel_size=kernel_size, data_layout='NCHW', out_dtype=dtype)
    mod = relay.Module.from_expr(out)
    print(mod)

    # compile kernels with history best records
    with autotvm.apply_history_best(graph_opt_sch_file):
        ctx = tvm.gpu()
        print("Compile...")

        with relay.build_config(opt_level=4):

            kernel_weights = tvm.nd.array(np.ones(kernel_shape, dtype=dtype), ctx)
            dict_params = {'kernel': kernel_weights}
            graph, lib, params = relay.build_module.build(mod, params=dict_params,
                                                          target=target, target_host=target_host)
            #print(params)
            #print(dict_params)


        # BENCHMARKING: Measure time with and without optimizations
        # upload parameters to device
        input_name = "data"

        data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype), ctx)
        module = runtime.create(graph, lib, ctx)
        module.set_input(input_name, data_tvm)
        module.set_input(**params)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=1000)

        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        #print(prof_res)
        print("Mean inference time auto-tuning (std dev): %.2f ms (%.2f ms)" % 
(np.mean(prof_res), np.std(prof_res)))


    out1 = relay.nn.conv2d(data, kernel, strides=strides, padding=padding,
                           dilation=dilation, channels=kernel_shape[0],
                           kernel_size=kernel_size, data_layout='NCHW', out_dtype=dtype)
    mod1 = relay.Module.from_expr(out1)
    #print(mod)

    ctx1 = tvm.gpu()
    graph1, lib1, params1 = relay.build_module.build(mod1, params=dict_params, target=target)
    #print(params)

    # BENCHMARKING: Measure time with and without optimizations
    input_name = "data"
    data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype), ctx1)
    module1 = runtime.create(graph1, lib1, ctx1)
    module1.set_input(input_name, data_tvm)
    module1.set_input(**params1)

    # evaluate
    print("Evaluate inference time cost...")
    ftimer = module1.module.time_evaluator("run", ctx1, number=10, repeat=1000)

    prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
    #print(prof_res)
    print("Mean inference time fallback (std dev): %.2f ms (%.2f ms)" % 
(np.mean(prof_res), np.std(prof_res)))

I was wondering whether I am missing something. I have tested a similar 
program using a Winograd convolution (also just one convolutional layer) and 
there I do see a performance improvement.
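
In case it helps, this is roughly how I am trying to double-check whether the workload stored in the tuning log matches the workload of the extracted task, since that is what `apply_history_best` matches on as far as I understand. This is a sketch, assuming the same `log_file` and `task[0]` as above and that records loaded from the log carry the task workload:

    # Sketch: compare the workload recorded in the tuning log with the task's workload
    for inp, _res in autotvm.record.load_from_file(log_file):
        print("log workload :", inp.task.workload)
        break
    print("task workload:", task[0].workload)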

I appreciate any help you can provide on this issue.