Greetings.
I hope this email finds you well. I’m trying to get JAX and PETSc to work
together in a zero-copy setup using the DLPack support in both, but I can’t
get it to work correctly. Ideally, I’d like to create a PETSc Vec using
petsc4py, pass it to JAX without copying, modify it inside a jitted JAX
function, and have that change reflected in the PETSc object, all without
any copies.

Of note: when I call from_dlpack, I get an error saying the alignment is
wrong and a copy must be made. Setting the alignment to 32 during PETSc’s
./configure stage makes the error message disappear, but the code still
doesn’t behave correctly. I’ve tried looking through the documentation, but
I’m getting a little turned around.
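One way to see what alignment PETSc actually gives the vector's storage is to inspect the data address of the NumPy view that petsc4py exposes. The helper below is a minimal sketch: `alignment_of` is a hypothetical name, not a PETSc or JAX API, and it assumes `Vec.getArray()` returns a view of the vector's own storage (as for a sequential host Vec) rather than a copy. It does not claim anything about which alignment JAX requires; it only reports what the buffer has.

```python
import numpy as np


def alignment_of(arr: np.ndarray, cap: int = 4096) -> int:
    """Largest power of two (capped at `cap`) that divides the array's data address."""
    addr = arr.__array_interface__["data"][0]
    align = 1
    # Keep doubling while the address is still a multiple of the next power of two.
    while align < cap and addr % (align * 2) == 0:
        align *= 2
    return align


# Hypothetical usage with petsc4py (needs a PETSc build; assumes getArray()
# returns a view of the Vec's storage rather than a copy):
# from petsc4py import PETSc
# v = PETSc.Vec().createSeq(1000)
# print(alignment_of(v.getArray()))  # compare against 16, 32, 64, ...
```

Running this against builds configured with different alignments would show whether the setting actually changes the address JAX receives.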

I’ve included a code snippet below:

    from petsc4py import PETSc

    import jax
    import jax.numpy as jnp
    from functools import partial


    # donate_argnums=(0,) lets XLA reuse x's buffer for the output.
    @partial(jax.jit, donate_argnums=(0,))
    def set_in_place(x):
        return x.at[:].set(3.0)


    print('\nTesting jax from_dlpack given a PETSc vector that was allocated by PETSc')

    x = jnp.ones((1000, 1))
    y_petsc = PETSc.Vec().createSeq(x.shape[0])
    y_petsc.set(0.0)
    print(hex(y_petsc.handle))

    # Sanity check: a second Vec built from a DLPack capsule of the first shares its memory.
    y2_petsc = PETSc.Vec().createWithDLPack(y_petsc.toDLPack('rw'))
    y2_petsc.set(-1.0)
    assert y_petsc.getValue(0) == y2_petsc.getValue(0)
    print('After creating a second PETSc vector via a DLPack of the first, '
          'modifying the memory of one affects the other.')

    # Import the PETSc storage into JAX without copying.
    #y = jnp.from_dlpack(y_petsc.toDLPack('rw'), copy=False)
    y = jnp.from_dlpack(y_petsc, copy=False)
    orig_ptr = y.unsafe_buffer_pointer()
    print(f'before: ptr at {hex(orig_ptr)}')

    y = set_in_place(y)
    print(f'after:  ptr at {hex(y.unsafe_buffer_pointer())}')

    # Assert 1: the donated buffer was reused for the output.
    assert orig_ptr == y.unsafe_buffer_pointer()

    # Assert 2: the JAX write is visible through PETSc.
    #assert y_petsc.getValue(0) == y[0], (
    #    f'The PETSc value {y_petsc.getValue(0)} did not match the JAX value {y[0]}, '
    #    'so modifying the JAX memory did not affect the PETSc memory.')


I’d like the bottom two asserts to both pass, but I can only get one of them
to pass at a time. If somebody is familiar with this issue, I’d greatly
appreciate the assistance.
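The first of those asserts compares JAX's buffer pointer before and after the jitted call, so it shows whether donation reused the buffer, not whether that buffer is PETSc's memory. A minimal sketch of a check that separates the two questions, by comparing JAX's pointer against the address of PETSc's own storage, is below. `data_address` and `check_aliasing` are hypothetical helpers; the sketch assumes a CPU-backed JAX array, where `unsafe_buffer_pointer()` is a host address, and that `Vec.getArray()` returns a view rather than a copy.

```python
import numpy as np


def data_address(host_view: np.ndarray) -> int:
    """Start address of a NumPy array's data buffer."""
    return host_view.__array_interface__["data"][0]


def check_aliasing(host_view: np.ndarray, jax_array) -> bool:
    """True if jax_array's buffer starts where host_view's data starts.

    Assumes a CPU-backed JAX array, whose unsafe_buffer_pointer() is a host address.
    """
    return jax_array.unsafe_buffer_pointer() == data_address(host_view)


# Hypothetical usage with the objects from the snippet above:
# petsc_view = y_petsc.getArray()
# y = jnp.from_dlpack(y_petsc, copy=False)
# print("import aliases PETSc storage:", check_aliasing(petsc_view, y))
# y = set_in_place(y)
# print("output still aliases PETSc storage:", check_aliasing(petsc_view, y))
```

If the first print is already False, the problem lies in the DLPack import; if only the second is False, it lies in how the jitted call handles the donated buffer.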

Respectfully,

Alberto
