Just to follow up on what @tqchen summarized previously, here's my 
understanding: 

### frontend converters
We want users who write frontend converters to be aware that certain operators are
stateful, and we can encourage them to write these operations in the A1 style. For
instance:

```
from tvm import relay
from tvm.relay import op as _op

def _mx_dropout_train(inputs, attrs, module):
    rate = attrs.get_float("p", 0.5)
    global_state = module['prng']
    state_ref = relay.RefCreate(global_state)
    read_state = relay.RefRead(state_ref)
    # the dropout_train operator outputs both y and the new state 
    y_state = _op.nn.dropout_train(inputs[0], read_state, rate=rate)
    # write back new state, return y 
    write_state = relay.RefWrite(state_ref, y_state[1])
    y = relay.Let(relay.var('ref_write'), write_state, y_state[0])
    return y
```

where `module['prng']` is a global variable representing the PRNG state in the
module. As of now, global variables are only used to represent functions, so we
need to extend them to also represent the random state.
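
As a minimal sketch of the gap, the snippet below shows how a global variable is
bound to a function today; the final, commented-out binding of a state value is the
hypothetical extension and does not exist yet:

```
import tvm
from tvm import relay

mod = tvm.IRModule()

# today: a global variable in the module can only be bound to a function ...
x = relay.var("x", shape=(1, 8))
mod["main"] = relay.Function([x], relay.nn.relu(x))

# ... the proposal would also allow binding one to the PRNG state, e.g.
# mod["prng"] = <initial PRNG state>   # hypothetical, not supported today
```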

### rewriting A1-style programs to A2-style ones

Let's say we have a function below with stateful ops:
```
def @func1(%x) {
  %0 = ref(@prng_state);
  %1 = %0^;
  %2 = nn.dropout_train(%x, %1, rate=0.7f);
  %3 = %2.1;
  let %ref_write: () = (%0 := %3);
  %2.0
}
```
In the rewriting pass, we detect that the global random state is used, and rewrite
the function that references it to the following:
```
def @func1_rewritten(%x, %state) {
  %2 = nn.dropout_train(%x, %state, rate=0.7f);
  (%2.0, %2.1)
}
```
Note that the function's output type is changed to a tuple that also contains the
new state. Meanwhile, we need to update all CallNodes of this function accordingly;
a sketch of such a call-site rewrite is shown below.
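
For example, a call site of `@func1` would be rewritten roughly as follows (a
sketch; `@caller` is a made-up function for illustration):

```
# before: a caller of @func1
def @caller(%a) {
  @func1(%a)
}

===>

# after: the caller takes the state as an extra argument and threads it through
def @caller_rewritten(%a, %state) {
  %0 = @func1_rewritten(%a, %state);
  %1 = %0.0;  # the original result
  %2 = %0.1;  # the new state
  (%1, %2)
}
```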
Here is another example: 

```
def @long_func(%x) {
  %0 = ref(@prng_state);
  %1 = %0^;
  %2 = nn.dropout_train(%x, %1, rate=0.7f);
  %3 = %2.1;
  %4 = (
    let %ref_write1: () = (%0 := %3);
    %2.0
  );
  %5 = %0^;
  %6 = nn.dropout_train(%4, %5, rate=0.1f);
  %7 = %6.1;
  let %ref_write: () = (%0 := %7);
  %6.0
}

===> 

def @long_func_rewritten(%x, %state) {
  %2 = nn.dropout_train(%x, %state, rate=0.7f);
  %3 = %2.1;
  %4 = %2.0;
  %6 = nn.dropout_train(%4, %3, rate=0.1f);
  (%6.0, %6.1)
}

```

Note that the pass implementation requires tracking the latest value of the 
global variable within each scope. For instance, the program below: 

```
def @func2(%x, %y) { # returns tensor
  if (%x) {
    add(%x, %y)
  } else {
    @func1(%y)
  }
}
```
would be rewritten to:
```
def @func2(%x, %y, %state) {  # returns (tensor, state) for both branches
  if (%x) {
    (add(%x, %y), %state) # the original state is also returned
  } else {
    @func1_rewritten(%y, %state) # returns the new state
  }
}
```
Since the pass needs to reason about the state within each scope, it would be easier
to implement it after the program has already been transformed into basic-block
normal form. A rough sketch of such a pass is shown below.
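
Here is a rough sketch of what such a pass could look like, written as an
`ExprMutator` in Python. The class name `StateRewriter` and the overall structure
are my own, it identifies state refs by variable name, and it only handles
straight-line code; merging the state across `if` branches (as in `@func2`) is
omitted.

```
from tvm import relay
from tvm.relay.expr_functor import ExprMutator


class StateRewriter(ExprMutator):
    """Sketch: turn A1-style (ref-based) state usage into A2-style threading."""

    def __init__(self, prng_gvar):
        super().__init__()
        self.prng_gvar = prng_gvar     # GlobalVar holding the PRNG state
        self.state_ref_names = set()   # let-bound vars aliasing ref(@prng_state)
        self.cur_state = None          # latest state value in the current scope

    def visit_function(self, fn):
        # thread the state through as an extra parameter ...
        state_var = relay.var("state")
        self.cur_state = state_var
        new_body = self.visit(fn.body)
        # ... and return it together with the original result
        return relay.Function(list(fn.params) + [state_var],
                              relay.Tuple([new_body, self.cur_state]))

    def visit_let(self, let):
        # drop `let %r = ref(@prng_state)` and remember %r as a state alias
        value = let.value
        if isinstance(value, relay.RefCreate) and \
                isinstance(value.value, relay.GlobalVar) and \
                value.value.name_hint == self.prng_gvar.name_hint:
            self.state_ref_names.add(let.var.name_hint)
            return self.visit(let.body)
        return super().visit_let(let)

    def visit_ref_read(self, r):
        # a read of the state ref becomes the latest state value in this scope
        if isinstance(r.ref, relay.Var) and r.ref.name_hint in self.state_ref_names:
            return self.cur_state
        return super().visit_ref_read(r)

    def visit_ref_write(self, r):
        # a write to the state ref only updates the tracked value; the write
        # itself becomes a no-op (an empty tuple)
        if isinstance(r.ref, relay.Var) and r.ref.name_hint in self.state_ref_names:
            self.cur_state = self.visit(r.value)
            return relay.Tuple([])
        return super().visit_ref_write(r)
```

A subsequent dead-code-elimination pass can then clean up the leftover unit-typed
bindings (e.g. the `%ref_write` lets) that this rewrite leaves behind.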

### discussions

What type do we use for the random state?
  - option a: use the empty tuple type. The runtime actually uses a global state,
and it relies on the deterministic execution order of the program to ensure
reproducibility.
  - option b: add a new type (e.g. `TypeRandState`), and the random state Object
actually carries the data structure used for generating random numbers (e.g.
`std::mt19937`). The state is passed around in the program, and invoking an
operator with the same state object always leads to the same deterministic
outputs. (A small sketch contrasting the two semantics is given after this list.)
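
To make the difference concrete, here is a toy Python analogy (purely illustrative;
`toy_rand` is a made-up LCG, not the proposed runtime implementation): under option
a, reproducibility depends on the order in which stateful calls execute, whereas
under option b each call is fully determined by the state object it receives.

```
def toy_rand(state):
    # a tiny LCG, purely for illustration
    new_state = (1103515245 * state + 12345) % (2 ** 31)
    return new_state / float(2 ** 31), new_state

# option a: the state is hidden in a global; reproducibility relies on the
# program executing its stateful calls in a deterministic order
_global_state = 42

def toy_rand_global():
    global _global_state
    value, _global_state = toy_rand(_global_state)
    return value

# option b: the state is threaded explicitly; calling an operator with the
# same state object always produces the same output and the same new state
state0 = 42
v1, state1 = toy_rand(state0)
v2, state2 = toy_rand(state1)
assert toy_rand(state0) == (v1, state1)
```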

@junrushao1994 @haichen @MarisaKirisame @ziheng would you like to provide some 
suggestions/comments?




