Hi Alberto,

1. To check the array pointer on the PETSc side, you can use
print(hex(y_petsc.array.ctypes.data)). You will then see a pointer mismatch
caused by the line y = jnp.from_dlpack(y_petsc, copy=False). This is because
you configured PETSc in double precision, but JAX uses single precision by
default. You can either add jax.config.update("jax_enable_x64", True) at the
start of your script to make JAX use double-precision numbers, or configure
PETSc with --with-precision=single.
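Point 1 can be checked without PETSc. This is a minimal sketch in which a float64 NumPy array stands in for y_petsc.array (the assumption is that petsc4py's Vec.array is a NumPy view of the Vec's own storage), and shares_buffer is a hypothetical helper. Whether the two pointers actually match can still depend on the buffer's alignment, so the sketch prints the result instead of asserting it.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Enable float64 before any JAX arrays are created. Without this, JAX
# canonicalizes float64 input to float32, and a float32 array cannot alias
# a float64 buffer, so a copy is unavoidable.
jax.config.update("jax_enable_x64", True)


def shares_buffer(host_view: np.ndarray, jax_array: jax.Array) -> bool:
    """Hypothetical helper: True if both arrays start at the same address."""
    return host_view.ctypes.data == jax_array.unsafe_buffer_pointer()


# Stand-in (assumption) for y_petsc.array, the NumPy view of a
# double-precision PETSc Vec.
host = np.zeros(1000, dtype=np.float64)

# copy is left at its default so the sketch still runs if JAX decides
# it has to copy (for example, because of alignment).
y = jnp.from_dlpack(host)

print("dtype:", y.dtype)
print("host ptr:", hex(host.ctypes.data))
print("JAX ptr: ", hex(y.unsafe_buffer_pointer()))
print("zero-copy:", shares_buffer(host, y))
```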

2. Once you fix this precision mismatch, the zero-copy conversion between PETSc
and JAX should work. However, .at[].set() in JAX is not guaranteed to operate
in place; array updates in JAX are performed out of place by design. You may
want to do the updates in PETSc instead, so that they don’t break the zero-copy
sharing.
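The out-of-place behaviour of .at[].set() can be seen in plain JAX, with no PETSc involved. This is a minimal sketch; set_all is a hypothetical name. As we understand it, donate_argnums (as used in Alberto's snippet) only allows XLA to reuse the donated buffer and does not require it, so an in-place result should not be relied on. On the PETSc side, an in-place call such as y_petsc.set(3.0) writes directly into the Vec's memory.

```python
import jax
import jax.numpy as jnp


@jax.jit
def set_all(x):
    # Functional update: builds and returns a new array; x is not modified.
    return x.at[:].set(3.0)


x = jnp.ones(5)
y = set_all(x)

print(x)  # still all ones: the input array was not written to
print(y)  # all threes, held in a separate output buffer
```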
Hong

From: petsc-users <petsc-users-boun...@mcs.anl.gov> on behalf of Alberto Cattaneo <bubu.catta...@gmail.com>
Date: Monday, July 7, 2025 at 8:40 AM
To: "petsc-users@mcs.anl.gov" <petsc-users@mcs.anl.gov>
Subject: [petsc-users] Petsc/Jax no copy interfacing issues

Greetings.
I hope this email finds you well. I’m trying to get JAX and PETSc to work
together in a no-copy setup using the DLPack tools in both. Unfortunately, I
can’t seem to get it to work right. Ideally, I’d like to create a PETSc Vec
object using petsc4py, pass it to a JAX array without copying, modify it in a
JAX jitted function, and have that change reflected in the PETSc object, all
without copying.

Of note: when I call the from_dlpack function, I get an error saying that the
alignment is wrong and a copy must be made. Changing the alignment to 32 at the
PETSc ./configure stage makes the error message disappear, but it still doesn’t
function correctly. I’ve tried looking through the documentation, but I’m
getting a little turned around.
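To see what alignment a buffer actually has before handing it to JAX, its start address can be inspected directly. This is a minimal sketch with a hypothetical helper, pointer_alignment, and a NumPy array standing in for the PETSc storage (with petsc4py, y_petsc.array.ctypes.data would be passed instead). The 32-byte figure comes from the configure change described above; the exact alignment JAX requires is not stated in this thread and is an assumption here.

```python
import numpy as np


def pointer_alignment(ptr: int) -> int:
    """Largest power of two that divides the address ptr (0 for a null pointer).

    ptr & -ptr isolates the lowest set bit, e.g. 96 (0b1100000) -> 32.
    """
    return ptr & -ptr


# Stand-in buffer (assumption); with petsc4py use y_petsc.array.ctypes.data.
buf = np.zeros(1000, dtype=np.float64)
addr = buf.ctypes.data

print(hex(addr), "is aligned to", pointer_alignment(addr), "bytes")
print("32-byte aligned:", addr % 32 == 0)
```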

I’ve included a code snippet below:


from functools import partial

import jax
import jax.numpy as jnp
from petsc4py import PETSc


@partial(jax.jit, donate_argnums=(0,))
def set_in_place(x):
    return x.at[:].set(3.0)


print('\nTesting jax from_dlpack given a PETSc vector that was allocated by PETSc')
x = jnp.ones((1000, 1))
y_petsc = PETSc.Vec().createSeq(x.shape[0])
y_petsc.set(0.0)
print(hex(y_petsc.handle))

y2_petsc = PETSc.Vec().createWithDLPack(y_petsc.toDLPack('rw'))
y2_petsc.set(-1.0)
assert y_petsc.getValue(0) == y2_petsc.getValue(0)
print('After creating a second PETSc vector via a DLPack of the first, modifying the memory of one affects the other.')

#y = jnp.from_dlpack(y_petsc.toDLPack('rw'), copy=False)
y = jnp.from_dlpack(y_petsc, copy=False)
orig_ptr = y.unsafe_buffer_pointer()
print(f'before: ptr at {hex(orig_ptr)}')
y = set_in_place(y)
print(f'after:  ptr at {hex(y.unsafe_buffer_pointer())}')
assert orig_ptr == y.unsafe_buffer_pointer()

#assert y_petsc.getValue(0) == y[0], f'The PETSc value {y_petsc.getValue(0)} did not match the JAX value {y[0]}, so modifying the JAX memory did not affect the PETSc memory.'



I’d like the last two asserts to pass, but I can only get one of them to pass.
If anybody is familiar with this issue, I’d greatly appreciate the assistance.

Respectfully,

Alberto
