Hi Alberto,

1. To check the array pointer on the PETSc side, you can run print(hex(y_petsc.array.ctypes.data)). You will then see a pointer mismatch caused by the line y = jnp.from_dlpack(y_petsc, copy=False). This happens because you configured PETSc in double precision, but JAX uses single precision by default. You can either add jax.config.update("jax_enable_x64", True) to make JAX use double-precision numbers, or configure PETSc for single precision.
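A minimal sketch of the precision point, assuming a NumPy float64 array as a stand-in for the buffer of a double-precision PETSc Vec (the name host is hypothetical; a real check would use y_petsc.array from petsc4py). It only shows the dtype behaviour and does not by itself prove zero-copy sharing:

```python
import jax

# Enable 64-bit types at startup, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for y_petsc.array in a double-precision PETSc build.
host = np.zeros(1000, dtype=np.float64)

# With x64 enabled, JAX keeps float64; with the default configuration
# the same call would produce a float32 array instead.
y = jnp.asarray(host)
default_dtype = jnp.ones(3).dtype

print(host.dtype, y.dtype, default_dtype)
```

If the dtypes agree, comparing hex(y_petsc.array.ctypes.data) with hex(y.unsafe_buffer_pointer()) in the real script is the next thing to check.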
2. Once you fix this precision mismatch, the in-place conversion between PETSc and JAX should work. However, .at[].set() in JAX is not guaranteed to operate in place; array updates in JAX are generally performed out of place by design. You can do the updates in PETSc instead, so that they do not break the zero-copy setup.

Hong

From: petsc-users <petsc-users-boun...@mcs.anl.gov> on behalf of Alberto Cattaneo <bubu.catta...@gmail.com>
Date: Monday, July 7, 2025 at 8:40 AM
To: "petsc-users@mcs.anl.gov" <petsc-users@mcs.anl.gov>
Subject: [petsc-users] Petsc/Jax no copy interfacing issues

Greetings. I hope this email reaches you well.

I’m trying to get JAX and PETSc to work together in a no-copy system using the DLPack tools in both. Unfortunately, I can’t seem to get it to work right. Ideally, I’d like to create a PETSc vec object using petsc4py, pass it to a JAX object without copying, make a change to it in a JAX jitted function, and have that change reflected in the PETSc object, all without copying.

Of note: when I try to do this, I get an error that the alignment is wrong and a copy must be made when I call the from-dlpack function. Changing the alignment in the PETSc ./config stage to 32 causes the error message to disappear, but even so it still doesn’t function correctly. I’ve tried looking through the documentation, but I’m getting a little turned around.
I’ve included a code snippet below:

    from petsc4py import PETSc as PETSc
    import jax
    from functools import partial
    import jax.numpy as jnp

    @partial(jax.jit, donate_argnums=(0,))
    def set_in_place(x):
        return x.at[:].set(3.0)

    print('\nTesting jax from_dlpack given a PETSc vector that was allocated by PETSc')
    x = jnp.ones((1000,1))
    y_petsc = PETSc.Vec().createSeq(x.shape[0])
    y_petsc.set(0.0)
    print(hex(y_petsc.handle))
    y2_petsc = PETSc.Vec().createWithDLPack(y_petsc.toDLPack('rw'))
    y2_petsc.set(-1.0)
    assert y_petsc.getValue(0) == y2_petsc.getValue(0)
    print('After creating a second PETSc vector via a DLPack of the first, modifying the memory of one affects the other.')
    #y = jnp.from_dlpack(y_petsc.toDLPack('rw'), copy=False)
    y = jnp.from_dlpack(y_petsc, copy=False)
    orig_ptr = y.unsafe_buffer_pointer()
    print(f'before: ptr at {hex(orig_ptr)}')
    y = set_in_place(y)
    print(f'after: ptr at {hex(y.unsafe_buffer_pointer())}')
    assert orig_ptr == y.unsafe_buffer_pointer()
    #assert y_petsc.getValue(0) == y[0], f'The PETSc value {y_petsc.getValue(0)} did not match the JAX value {y[0]}, so modifying the JAX memory did not affect the PETSc memory.'

I’d like the bottom two asserts to pass, but I can only get one of them. If somebody is familiar with this issue I’d greatly appreciate the assistance.

Respectfully:
Alberto
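As a small illustration of why the x.at[:].set(3.0) call in set_in_place can end up in a different buffer, here is a JAX-only sketch with no PETSc involved. It shows the functional-update semantics outside of jit. Under jit with donated arguments the compiler may reuse the input buffer, but that should be treated as an optimization rather than a guarantee:

```python
import jax.numpy as jnp

x = jnp.zeros(4)

# .at[].set() is a functional update: it returns a new array and
# leaves the original array untouched.
y = x.at[:].set(3.0)

print(x)  # x still holds zeros
print(y)  # y holds the updated values
```

In the snippet above, the update that keeps the memory shared is the one done on the PETSc side, for example y_petsc.set(3.0), which writes to the PETSc-owned buffer.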