Hi, 

I was trying to import the NCF model ([MXNet NCF](https://github.com/apache/incubator-mxnet/tree/master/example/neural_collaborative_filtering)) into TVM with the Relay MXNet frontend.

The import fails on an unsupported operator: `LogisticRegressionOutput`.

Based on the `SigmoidBinaryCrossEntropyLoss` code in Gluon:
https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/loss.py#L229

I tried to implement it myself:

```
def _mx_logistic_regression_output(inputs, attrs):
    pred = inputs[0]   # data input of LogisticRegressionOutput
    label = inputs[1]  # label input of LogisticRegressionOutput
    # We use the stable formula:
    # max(pred, 0) - pred * label + log(1 + exp(-abs(pred)))
    one = _expr.const(1, dtype="float32")
    exp_neg_abs_x = _op.exp(_op.negative(_op.abs(pred)))  # exp(-abs(pred))
    log_term = _op.log(_op.add(one, exp_neg_abs_x))       # log(1 + exp(-abs(pred)))
    # relu(pred) is max(pred, 0); it must appear only once in the sum
    loss = _op.nn.relu(pred) - pred * label + log_term
    return loss
```
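To sanity-check the stable formula from the comment independently of TVM, here is a small NumPy sketch that compares it with the direct sigmoid binary cross-entropy, -(label * log(p) + (1 - label) * log(1 - p)) with p = sigmoid(pred). The helper names `naive_sigmoid_bce` and `stable_sigmoid_bce` are made up for this illustration and are not part of MXNet or TVM.

```python
import numpy as np


def naive_sigmoid_bce(pred, label):
    """Direct sigmoid binary cross-entropy; can overflow/underflow for large |pred|."""
    p = 1.0 / (1.0 + np.exp(-pred))
    return -(label * np.log(p) + (1.0 - label) * np.log(1.0 - p))


def stable_sigmoid_bce(pred, label):
    """Stable form: max(pred, 0) - pred * label + log(1 + exp(-abs(pred)))."""
    return np.maximum(pred, 0) - pred * label + np.log1p(np.exp(-np.abs(pred)))


pred = np.array([-3.0, -0.5, 0.0, 1.0, 4.0])
label = np.array([0.0, 1.0, 0.0, 1.0, 0.0])
# For moderate inputs both forms should agree element-wise.
print(np.allclose(naive_sigmoid_bce(pred, label), stable_sigmoid_bce(pred, label)))
```

The two forms are algebraically identical; the stable one only avoids taking `log` of values that round to 0 when `|pred|` is large.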

I tested the function with the following code:
```
import mxnet as mx
import numpy as np
from tvm import relay


def test_forward_logistic_regression():
    m = mx.nd.ones((3, 4))   # data (prediction)
    n = mx.nd.zeros((3, 4))  # label
    vm = mx.sym.Variable('m')
    vn = mx.sym.Variable('n')

    # Reference result: run the operator in MXNet
    out = mx.sym.LogisticRegressionOutput(data=vm, label=vn)
    exec_ = out.bind(mx.cpu(), {'m': m, 'n': n})
    exec_.forward()
    mx_res = exec_.outputs[0].asnumpy()

    # Import the same symbol into Relay and run it on LLVM
    shape_dict = {"m": m.shape, "n": n.shape}
    mod, params = relay.frontend.from_mxnet(out, shape_dict)
    intrp = relay.create_executor("graph", mod=mod, target='llvm')
    op_res = intrp.evaluate()(m.asnumpy(), n.asnumpy())

    print(op_res.asnumpy())
    print(mx_res)
    np.testing.assert_allclose(op_res.asnumpy(), mx_res, rtol=1e-5)
```
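For the inputs used in this test (`m` all ones as the prediction, `n` all zeros as the label), the candidate outputs can be worked out by hand. The sketch below assumes, per the MXNet operator documentation, that `LogisticRegressionOutput` returns the element-wise sigmoid of `data` in the forward pass and uses the label only when computing gradients; it then computes both that sigmoid value and the stable loss from the converter's comment, so the two can be compared with what the test prints.

```python
import numpy as np

pred = np.ones((3, 4), dtype=np.float32)    # same values as m in the test
label = np.zeros((3, 4), dtype=np.float32)  # same values as n in the test

# Assumed MXNet forward output: element-wise sigmoid of the data input.
sigmoid_out = 1.0 / (1.0 + np.exp(-pred))

# Stable sigmoid binary cross-entropy from the converter's comment.
bce_loss = np.maximum(pred, 0) - pred * label + np.log1p(np.exp(-np.abs(pred)))

print(sigmoid_out[0, 0])  # approximately 0.731
print(bce_loss[0, 0])     # approximately 1.313
```

Comparing these two numbers with the MXNet and TVM outputs should show which quantity each side is actually returning.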

The TVM result differs from the MXNet result. I also tried other formulas, but they still don't match, and I'm quite confused about how to implement this operator.
Does anyone know how to fix this? Thanks!