Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions essos/custom_dipole_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os
import jax.numpy as jnp
from simsopt.geo import SurfaceRZFourier
from simsopt.field.magneticfieldclasses import DipoleField as SimsoptDipoleField
from jax import jit, vmap
from functools import partial
import time
import jax
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from essos.util import read_famus_dipoles
from essos.fields import DipoleField
jax.config.update('jax_enable_x64', True)

def compare_dipole_fields(surface_file, famus_file,data, output_dir="output", plot=False, nphi=16, ntheta=16):
"""Compare SIMSOPT and custom dipole field calculations."""
positions, moments, Ic, pho = data
mask = (Ic == 1)
positions = positions[mask]
moments = moments[mask]
pho = pho[mask]
s_plot = SurfaceRZFourier.from_focus(surface_file, quadpoints_phi=jnp.linspace(0, 1, nphi), quadpoints_theta=jnp.linspace(0, 1, ntheta))
gamma = s_plot.gamma().reshape((-1, 3))
unitnormal = s_plot.unitnormal().reshape((-1, 3))
if positions.size == 0:
print("No dipoles found in famus_file")
field_simsopt = None
else:
if positions.ndim == 1:
positions = positions.reshape(-1, 3)
if moments.ndim == 1:
moments = moments.reshape(-1, 3)
field_simsopt = SimsoptDipoleField(positions, moments, stellsym=True, nfp=s_plot.nfp)
if field_simsopt is not None:
field_simsopt.set_points(gamma)
start = time.time()
B_simsopt = field_simsopt.B().reshape((-1, 3))
simsopt_time = time.time() - start
else:
B_simsopt = None
simsopt_time = 0.0
field_essos = DipoleField(positions, moments, pho, stellsym=True, nfp=s_plot.nfp)
start = time.time()
B_essos = field_essos.B(gamma)
essos_time = time.time() - start
Bnormal_simsopt = jnp.sum(B_simsopt * unitnormal, axis=1).reshape((nphi, ntheta)) if field_simsopt is not None else jnp.zeros((nphi, ntheta))
Bnormal_essos = jnp.sum(B_essos * unitnormal, axis=1).reshape((nphi, ntheta))
diff_Bn = Bnormal_essos - Bnormal_simsopt
max_diff_Bn = jnp.max(jnp.abs(diff_Bn))
mean_diff_Bn = jnp.mean(jnp.abs(diff_Bn))
print(f"Max |ΔB·n|: {max_diff_Bn}")
print(f"Mean |ΔB·n|: {mean_diff_Bn}")
if plot:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
titles = ['SIMSOPT B·n', 'ESSOS B·n', 'ΔB·n']
data_sets = [Bnormal_simsopt, Bnormal_essos, diff_Bn]
phi_grid = jnp.linspace(0, 1, nphi)
theta_grid = jnp.linspace(0, 1, ntheta)
vmin = min(d.min() for d in data_sets if d.size > 0)
vmax = max(d.max() for d in data_sets if d.size > 0)
ims = []
for ax, data, title in zip(axes, data_sets, titles):
im = ax.contourf(phi_grid, theta_grid, data.T, levels=20, vmin=vmin, vmax=vmax, cmap='viridis')
ax.set_xlabel('Phi')
ax.set_ylabel('Theta')
ax.set_title(title)
ims.append(im)
#ax.colorbar(im,orientation='horizontal', label='B·n (T)')
fig.colorbar(im, ax=ax, orientation='horizontal',
fraction=0.046, pad=0.25, label='B·n (T)')


#plt.subplots_adjust(bottom=0.35, wspace=0.3, left=0.05, right=0.95)
#cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.05])
#cbar = fig.colorbar(ims[0], cax=cbar_ax, orientation='horizontal', label='B·n (T)')
os.makedirs(output_dir, exist_ok=True)
plt.savefig(os.path.join(output_dir, 'b_n_plot.png'), bbox_inches='tight', dpi=150)
plt.close(fig)
return field_essos, s_plot, gamma, unitnormal, essos_time, simsopt_time
68 changes: 68 additions & 0 deletions essos/examples/compute_G_symmetric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import jax
import jax.numpy as jnp
import numpy as np

MU0_4PI = 1e-7

def _bn_one_copy_pmap(surf_pts, surf_n, mag_pos, mag_mom):
"""Bn at surf_pts from magnets. Uses pmap for parallelism."""
n_devices = jax.device_count()
n_points = len(surf_pts)
remainder = n_points % n_devices
if remainder != 0:
pad = n_devices - remainder
surf_pts = jnp.concatenate([surf_pts, jnp.zeros((pad, 3), surf_pts.dtype)])
surf_n = jnp.concatenate([surf_n, jnp.zeros((pad, 3), surf_n.dtype)])

batch = len(surf_pts) // n_devices
pts_s = surf_pts.reshape(n_devices, batch, 3)
n_s = surf_n.reshape(n_devices, batch, 3)

m_pos = jnp.array(mag_pos)
m_mom = jnp.array(mag_mom)

def kernel(pts, norms):
P = jnp.expand_dims(pts, 1)
M_pos = jnp.expand_dims(m_pos, 0)
M_vec = jnp.expand_dims(m_mom, 0)
N = jnp.expand_dims(norms, 1)
R = P - M_pos
R_mag = jnp.linalg.norm(R, axis=2, keepdims=True)
dot_mr = jnp.sum(M_vec * R, axis=2, keepdims=True)
dot_rn = jnp.sum(R * N, axis=2, keepdims=True)
dot_mn = jnp.sum(M_vec * N, axis=2, keepdims=True)
term1 = 3.0 * dot_mr * dot_rn / (R_mag**5 + 1e-30)
term2 = -dot_mn / (R_mag**3 + 1e-30)
return jnp.squeeze((term1 + term2) * MU0_4PI, axis=2)

pts_d = jax.device_put_sharded(list(pts_s), jax.local_devices())
n_d = jax.device_put_sharded(list(n_s), jax.local_devices())
G_s = jax.pmap(kernel)(pts_d, n_d)
G_s.block_until_ready()
G_full = G_s.reshape(-1, len(m_pos))
if remainder != 0:
G_full = G_full[:n_points]
return G_full

def compute_G_symmetric(positions, moments, surf_pts, surf_n, nfp=2, stellsym=True):
"""
Build G (n_surf, n_mag) summing contributions from all symmetric copies.
Uses pmap for fast parallel computation.
"""
n_surf = len(surf_pts)
n_mag = len(positions)
G = jnp.zeros((n_surf, n_mag), dtype=jnp.float32)

stell_list = [1.0, -1.0] if stellsym else [1.0]
for stell in stell_list:
pos_s = positions * jnp.array([1.0, stell, stell])
mom_s = moments * jnp.array([stell, 1.0, 1.0]) if stellsym else moments
for i in range(nfp):
angle = 2 * jnp.pi * i / nfp
c, s = jnp.cos(angle), jnp.sin(angle)
R_mat = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0., 0., 1.0]])
pos_r = pos_s @ R_mat.T
mom_r = mom_s @ R_mat.T
G = G + _bn_one_copy_pmap(surf_pts, surf_n, pos_r, mom_r)

return G
Binary file not shown.
Loading