I am trying to register an API that implements the gradient of `relay.concatenate`. Below is the code I implemented, but I think some data is missing.
The following code is the c++ code I implemented for concatenate grad namespace tvm { namespace relay { //concatenate split bool ConcatenateGradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types = [ograd, tensor_tuple, result] CHECK_EQ(types.size(), 3); **DLOG(INFO) << "input type: " << types <<std::endl;** const auto* ograd = types[0].as<TensorTypeNode>(); const auto* tensor_tuple = types[1].as<TupleTypeNode>(); DLOG(INFO) << "tensor_tuple: " << tensor_tuple->type_key() <<std::endl; DLOG(INFO) << "tensor_tuple fields type: " << tensor_tuple->fields->type_key() <<std::endl; DLOG(INFO) << "tensor_tuple size: " << tensor_tuple->fields.size() <<std::endl; DLOG(INFO) << "tensor_tuple 1: " << tensor_tuple->fields[0]->type_key() <<std::endl; DLOG(INFO) << "tensor_tuple 2: " << tensor_tuple->fields[1]->type_key() <<std::endl; const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]); const auto param = attrs.as<ConcatenateAttrs>(); CHECK(param != nullptr); auto axis = param->axis; CHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" << ", but got axis = " << axis << ", and ndim = " << ndim; axis = axis < 0 ? 
ndim + axis : axis; // Calculate shape reporter->Assign(types[2], types[1]); return true; } Array<Tensor> ConcatenateGradCompute(const Attrs& attrs, const Array<Tensor>& inputs, const Type& out_type, const Target& target) { DLOG(INFO) << "Grad Compute" <<std::endl; const auto* param = attrs.as<ConcatenateAttrs>(); //const auto* param = attrs.as<SplitAttrs>(); CHECK(param != nullptr); auto axis = param->axis; const auto test = out_type.as<TupleTypeNode>(); int64_t num_sections = 2; return Array<Tensor>{ topi::split_sections(inputs[0], num_sections, param->axis) }; } Expr MakeConcatenateGrad(Expr ograd, Expr tensor_tuple, int axis ) { //auto attrs = make_node<SplitAttrs>(); auto attrs = make_node<ConcatenateAttrs>(); attrs->axis = axis; DLOG(INFO) << "tensor_tuple type: " << tensor_tuple->type_key() <<std::endl; const auto tensors = tensor_tuple.as<TupleTypeNode>(); static const Op& op = Op::Get("concatenate_grad"); return CallNode::make(op, {ograd, tensor_tuple}, Attrs(attrs), {}); } TVM_REGISTER_API("relay.op._make.concatenate_grad") .set_body_typed(MakeConcatenateGrad); RELAY_REGISTER_OP("concatenate_grad") .describe(R"code(concatenate grad. )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.ConcatenateAttrs") .set_num_inputs(2) .add_argument("ograd", "Tensor", "The gradient of output.") .add_argument("tensor_tuple", "Tensor", "The input list of tensors.") .set_support_level(1) .add_type_rel("ConcatenateGradRel", ConcatenateGradRel) .set_attr<FTVMCompute>("FTVMCompute", ConcatenateGradCompute) .set_attr<TOpPattern>("TOpPattern", kInjective); } // namespace relay } // namespace tvm I use relay.concatenate connected 2 matrices. one shape is 2x3 and the other is 3x3, the concatenate axis is 0, so the result is a 5x3 matrices. At the bottom there is my test case code. 
In line 13 of the code (the `DLOG(INFO) << "input type: " << types` statement), I print the type information. I expected to get:

`[TensorType([5, 3], float32), TupleTypeNode([TensorType([2, 3], float32), RefTypeNode(TensorType([2, 3], float32))], [TensorType([3, 3], float32), RefTypeNode(TensorType([3, 3], float32))]), IncompleteTypeNode(0, 0x3b4f2f0)]`

But what I actually got is:

`[IncompleteTypeNode(0, 0x50d7f50), TupleTypeNode([TensorType([2, 3], float32), RefTypeNode(TensorType([2, 3], float32))]), IncompleteTypeNode(0, 0x516ada0)]`

I found that `types[1]` contains only the first of the concatenated matrices (the 2x3 one); the entry for the 3x3 matrix is missing. How can I get the full value of the Tuple type?

Below is a partial error log:

```
TVMError: Error(s) have occurred. We have annotated the program with them:
In main:
v0.0.1
fn (%x: Tensor[(2, 3), float32], %y: Tensor[(3, 3), float32]) {
  %0 = fn () -> () {
    ()
  }
  let %x1 = ref(%0)
  %1 = zeros_like(%x)
  %2 = ref(%1)
  let %x2 = (%x, %2)
  %3 = zeros_like(%y)
  %4 = ref(%3)
  let %x3 = (%y, %4)
  %12 = fn (%x5: (Tensor[(2, 3), float32], ref(Tensor[(2, 3), float32])), %y1: (Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))) {
    let %x6 = (%x5, %y1)
    %5 = %x6.0
    let %x7 = concatenate(%5)an internal invariant was violated while typechecking your program [18:00:48] /home/rui.huang/tvm-git/3rdparty/HalideIR/src/tvm/node/node.h:264: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [18:00:48] /home/rui.huang/tvm-git/3rdparty/HalideIR/src/tvm/node/node.h:264: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [18:00:48] /home/rui.huang/tvm-git/3rdparty/HalideIR/src/tvm/node/node.h:264: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [18:00:48] /home/rui.huang/tvm-git/3rdparty/HalideIR/src/tvm/node/node.h:264: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
;
    %6 = zeros_like(%x7)
    let %x8 = ref(%6)
    let %x9 = %x1^
    %11 = fn () -> () {
      let %x11 = %x8^
      %7 = %x6.1an internal invariant was violated while typechecking your program [18:00:48] /home/rui.huang/tvm-git/src/relay/pass/type_solver.cc:119: Check failed: resolved.defined(): Unable to unify parent types: TupleTypeNode([TensorType([3, 3], float32), RefTypeNode(TensorType([3, 3], float32))]) and RefTypeNode(IncompleteTypeNode(0, 0x3a00b10))
      let %x12 = %7^
      %8 = %x6.1an internal invariant was violated while typechecking your program [18:00:48] /home/rui.huang/tvm-git/src/relay/pass/type_solver.cc:119: Check failed: resolved.defined(): Unable to unify parent types: TupleTypeNode([TensorType([3, 3], float32), RefTypeNode(TensorType([3, 3], float32))]) and RefTypeNode(IncompleteTypeNode(0, 0x3bc3680)) ;
      %9 = concatenate_grad(%x11, %5)...
```

Every other operator takes a single value as each argument. `relay.concatenate` is different: its forward computation takes one `Tuple(list([x, y, z...]))` argument. The forward implementation of concatenate is in `tvm/relay/op/tensor.py` (line 719).

The following code is the concatenate gradient interface I registered.
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
    """Gradient of ``relay.concatenate``.

    ``concatenate`` has a single argument (the tuple of tensors), so this
    returns a one-element list holding the gradient w.r.t. that tuple, as
    computed by the ``concatenate_grad`` operator.
    """
    axis = orig.attrs.axis
    data = orig.args[0]
    return [_make.concatenate_grad(grad, data, axis)]


# The following code is my test case.
def relay_concatenate_grad(func, x, y, axis=0):
    """Run the forward and backward pass of a two-input ``relay.concatenate``.

    NOTE(review): ``func``, ``x`` and ``y`` are currently ignored; fresh random
    inputs of shape (2, 3) and (3, 3) are always generated, matching the setup
    described in this post.

    Returns the gradients produced by evaluating the Relay gradient program.
    """
    dshape = (2, 3)
    oshape = (3, 3)
    dtype = 'float32'
    x = np.random.rand(*dshape).astype(dtype)
    y = np.random.rand(*oshape).astype(dtype)
    input_parameters = {'x': x, 'y': y}
    intrp = relay.create_executor(ctx=tvm.context('llvm', 0), target='llvm')
    x_var = relay.var('x', relay.TensorType(x.shape, dtype))
    y_var = relay.var('y', relay.TensorType(y.shape, dtype))
    fwd_func = relay.Function([x_var, y_var], relay.concatenate([x_var, y_var], axis))

    # Sanity-check the forward pass against NumPy before differentiating.
    op_res = intrp.evaluate(fwd_func)(**input_parameters)
    np.testing.assert_allclose(op_res.asnumpy(), np.concatenate([x, y], axis=axis))

    # Build the gradient program once; reuse it for printing and type inference.
    grad_func = gradient(fwd_func)
    print("grad fun:\n", grad_func)
    bwd_func = infer_type(grad_func)
    _, grads = intrp.evaluate(bwd_func)(**input_parameters)
    print("relay_ grad:\n", grads)
    return grads

# ---
# [Visit Topic](https://discuss.tvm.ai/t/relay-concatenate-missing-data-when-development-the-gradient-of-relay-concatenate/3595/1) to respond.
# You are receiving this because you enabled mailing list mode.
# To unsubscribe from these emails, [click here](https://discuss.tvm.ai/email/unsubscribe/c16ae972300dba127f342101a08156da3ba673fbc4d280f539292bfc4eefe19d).
# Tianqi Chen, UW, Seattle, WA, 98105, United States
# http://tracking.discuss.tvm.ai/tracking/unsubscribe?msgid=iwr5OpZi_MfkRYrBpsRv8w2