Greetings. I hope this email finds you well. I’m trying to get JAX and PETSc to share memory without copying, using the DLPack support in both, but I can’t get it to work. Ideally, I’d like to create a PETSc Vec with petsc4py, hand it to JAX as an array without copying, modify it inside a jitted JAX function, and have that change reflected in the PETSc object.
Of note: when I call from_dlpack, I get an error saying that the alignment is wrong and a copy must be made. Changing the alignment to 32 at the PETSc ./configure stage makes the error message disappear, but it still doesn’t work correctly. I’ve tried looking through the documentation, but I’m getting a little turned around. I’ve included a code snippet below:

    from petsc4py import PETSc
    import jax
    from functools import partial
    import jax.numpy as jnp

    @partial(jax.jit, donate_argnums=(0,))
    def set_in_place(x):
        return x.at[:].set(3.0)

    print('\nTesting jax from_dlpack given a PETSc vector that was allocated by PETSc')
    x = jnp.ones((1000,1))
    y_petsc = PETSc.Vec().createSeq(x.shape[0])
    y_petsc.set(0.0)
    print(hex(y_petsc.handle))

    y2_petsc = PETSc.Vec().createWithDLPack(y_petsc.toDLPack('rw'))
    y2_petsc.set(-1.0)
    assert y_petsc.getValue(0) == y2_petsc.getValue(0)
    print('After creating a second PETSc vector via a DLPack of the first, modifying the memory of one affects the other.')

    #y = jnp.from_dlpack(y_petsc.toDLPack('rw'), copy=False)
    y = jnp.from_dlpack(y_petsc, copy=False)
    orig_ptr = y.unsafe_buffer_pointer()
    print(f'before: ptr at {hex(orig_ptr)}')
    y = set_in_place(y)
    print(f'after: ptr at {hex(y.unsafe_buffer_pointer())}')
    assert orig_ptr == y.unsafe_buffer_pointer()
    #assert y_petsc.getValue(0) == y[0], f'The PETSc value {y_petsc.getValue(0)} did not match the JAX value {y[0]}, so modifying the JAX memory did not affect the PETSc memory.'

I’d like the bottom two asserts to pass, but I can only get one of them to. If anybody is familiar with this issue, I’d greatly appreciate the assistance.

Respectfully,
Alberto