Surface Fitting with Kernel Methods¶
The rgpycrumbs.surfaces module provides high-performance, differentiable surface models using JAX. These models are designed for constructing Potential Energy Surfaces (PES) and other multidimensional functions, supporting both standard observations and gradient-enhanced data.
Available Models¶
A variety of kernel-based models are provided via a unified interface:
TPS/RBF: Thin Plate Splines, standard for smooth interpolation.
Matérn 5/2: A stationary kernel with finite differentiability.
IMQ: Inverse multiquadric kernel, often more stable for large spans.
Gradient-Enhanced: Versions of Matérn, Squared Exponential (SE), RQ, and IMQ that incorporate energy gradients directly into the covariance structure.
Nystrom Gradient IMQ: A memory-efficient approximation for large datasets.
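For reference, textbook forms of several of the kernels above, written as functions of the pairwise distance \(r\), look like the following. This is a plain NumPy sketch; the library's exact parameterization may differ.

```python
import numpy as np

def tps(r):
    """Thin plate spline: r^2 log(r), with the r = 0 limit taken as 0."""
    return np.where(r > 0, r**2 * np.log(np.where(r > 0, r, 1.0)), 0.0)

def matern52(r, ls=1.0):
    """Matern 5/2: a stationary kernel that is twice differentiable."""
    s = np.sqrt(5.0) * r / ls
    return (1.0 + s + s**2 / 3.0) * np.exp(-s)

def imq(r, eps=1.0):
    """Inverse multiquadric: 1 / sqrt(1 + (eps * r)^2)."""
    return 1.0 / np.sqrt(1.0 + (eps * r) ** 2)

r = np.linspace(0.0, 3.0, 7)
print(matern52(r)[0], imq(r)[0])  # both equal 1.0 at r = 0
```

The stationary kernels (Matérn, IMQ) decay monotonically with distance, which is what makes the predictive variance revert to the prior far from the data.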
Uncertainty and Variance¶
Every model in surfaces.py implements a predict_var(x_query) method. This calculates the posterior predictive variance, providing a measure of uncertainty:
Interpretation: The variance represents the model’s confidence. At training points (and within the kernel’s length scale), the variance is low. In data-sparse regions, it reverts toward the prior variance (e.g., \(1/\epsilon\) for IMQ or \(1.0\) for Matérn).
Optimization: Parameters like length scales (\(l\), \(\epsilon\)) and noise scalars are optimized via Maximum Likelihood Estimation (MLE) or Maximum A Posteriori (MAP) with physically-informed priors.
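The behaviour described above follows from the standard GP posterior variance, \(\mathrm{Var}(x_*) = k(x_*, x_*) - k_*^\top K^{-1} k_*\). A toy one-dimensional sketch with a squared-exponential kernel and unit prior variance (illustrative only, not the library's implementation):

```python
import numpy as np

def se_kernel(a, b, ls=0.5):
    """Squared-exponential kernel between two sets of 1D points."""
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / ls) ** 2)

x_train = np.array([0.0, 1.0, 2.0])
noise = 1e-6
K = se_kernel(x_train, x_train) + noise * np.eye(len(x_train))

def toy_predict_var(x_query):
    """Posterior variance: k(x*, x*) - k*^T K^{-1} k*, prior variance 1.0."""
    k_star = se_kernel(x_train, x_query)      # (n_train, n_query)
    v = np.linalg.solve(K, k_star)
    return 1.0 - np.sum(k_star * v, axis=0)

print(toy_predict_var(np.array([1.0])))   # ~0 at a training point
print(toy_predict_var(np.array([10.0])))  # ~1 far from the data
```

At a training point the variance collapses to roughly the noise level; far from the data it reverts to the prior variance of 1.0, mirroring the interpretation above.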
Variance Windowing and Chunking¶
To prevent Out-Of-Memory (OOM) errors during the evaluation of large grids (e.g., 2D slice visualizations), both the prediction and variance calls are internally chunked.
model = GradientMatern(x_obs, y_obs, gradients=g_obs)
# Queries are evaluated in windows (default chunk_size=500); override per call:
z_preds = model(x_grid, chunk_size=1000)
z_vars = model.predict_var(x_grid, chunk_size=1000)
This “windowed” evaluation ensures that the large \(N_{query} \times N_{train}\) cross-covariance matrices do not exhaust system memory.
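The windowing idea can be sketched generically: only one chunk of query points is materialized against the training set at a time, so peak memory scales with the chunk size rather than the full grid. Plain NumPy, illustrative only; the library's internal chunking may differ in detail.

```python
import numpy as np

def predict_chunked(predict_fn, x_query, chunk_size=500):
    """Evaluate a query set in windows so only a (chunk_size x N_train)
    cross-covariance block would be materialized at a time."""
    out = []
    for start in range(0, len(x_query), chunk_size):
        out.append(predict_fn(x_query[start:start + chunk_size]))
    return np.concatenate(out)

# Toy stand-in for a fitted model's prediction function.
f = lambda xq: np.sin(xq)
x_grid = np.linspace(0.0, 6.0, 2001)
chunked = predict_chunked(f, x_grid, chunk_size=256)
print(np.allclose(chunked, f(x_grid)))  # True: chunking changes memory use, not results
```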
Numerical Stability¶
Gradient-enhanced models use auto-differentiation (via jax.grad and jax.jacfwd) to construct the full \((D+1)N \times (D+1)N\) covariance matrix.
Cholesky Fallback: The safe_cholesky_solve utility attempts Cholesky decomposition with increasing jitter to handle near-singular matrices.
Float Precision: By default, models use float32 for performance and compatibility with visualization backends, but jax_enable_x64 can be toggled if higher precision is required.
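The jitter-escalation idea behind the Cholesky fallback can be sketched as follows. This is a NumPy illustration of the technique, not the library's actual code.

```python
import numpy as np

def cholesky_solve_with_jitter(K, y, jitters=(0.0, 1e-8, 1e-6, 1e-4, 1e-2)):
    """Try Cholesky with progressively larger diagonal jitter until the
    factorization succeeds, then solve K x = y via the triangular factor."""
    for jitter in jitters:
        try:
            L = np.linalg.cholesky(K + jitter * np.eye(len(K)))
        except np.linalg.LinAlgError:
            continue  # not positive definite at this jitter level; escalate
        z = np.linalg.solve(L, y)
        return np.linalg.solve(L.T, z)
    raise np.linalg.LinAlgError("Cholesky failed at all jitter levels")

# Exactly duplicated training points give a singular covariance matrix:
K = np.ones((2, 2))
x = cholesky_solve_with_jitter(K, np.array([1.0, 1.0]))
print(x)  # the first nonzero jitter level succeeds
```

The jitter biases the solution slightly, which is why the troubleshooting advice below also suggests removing duplicate points rather than relying on large jitter.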
Troubleshooting¶
“Out of memory” during surface fitting¶
Problem: JAX runs out of GPU/CPU memory when fitting large datasets.
Solution:
Use the Nystrom approximation for datasets with >1000 points:
from rgpycrumbs.surfaces import get_surface_model
model = get_surface_model("grad_imq_ny")(x_obs, y_obs, gradients=g_obs, n_inducing=200)
Reduce the n_inducing parameter (default: 100)
Use float32 precision instead of float64
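The memory saving comes from replacing the full \(N \times N\) kernel matrix with a low-rank factorization through \(m\) inducing points, \(K \approx K_{nm} K_{mm}^{-1} K_{mn}\), which needs only \(O(Nm)\) storage. A plain NumPy sketch of the idea (illustrative; not the library's implementation):

```python
import numpy as np

def se(a, b, ls=1.0):
    """Squared-exponential kernel between two sets of 1D points."""
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / ls) ** 2)

rng = np.random.default_rng(0)
x = np.sort(rng.uniform(0.0, 4.0, 400))
x_ind = np.linspace(0.0, 4.0, 30)         # m inducing points, m << N

K_nm = se(x, x_ind)                        # (N, m) cross-covariance
K_mm = se(x_ind, x_ind) + 1e-8 * np.eye(len(x_ind))

# Nystrom: K ~= K_nm K_mm^{-1} K_mn, never forming the N x N matrix exactly.
K_approx = K_nm @ np.linalg.solve(K_mm, K_nm.T)
K_full = se(x, x)
print(np.max(np.abs(K_full - K_approx)))   # small for a smooth kernel
```

For smooth kernels the eigenvalue spectrum decays quickly, so a modest number of well-spread inducing points captures almost all of the covariance structure.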
Cholesky decomposition fails¶
Problem: safe_cholesky_solve reports numerical instability.
Solution:
Increase the
smoothingparameter (adds jitter to diagonal):model = GradientMatern(x_obs, y_obs, gradients=g_obs, smoothing=1e-3)Check for duplicate or near-duplicate training points
Normalize input coordinates to [0, 1] range
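Rescaling inputs, as suggested above, can be done with a small min-max helper. The helper name below is illustrative and not part of the library:

```python
import numpy as np

def normalize_unit_box(x):
    """Min-max scale each input dimension to [0, 1]; also return (lo, span)
    so query points can be mapped into the same box later."""
    lo = x.min(axis=0)
    span = x.max(axis=0) - lo
    span = np.where(span > 0, span, 1.0)  # guard constant dimensions
    return (x - lo) / span, lo, span

# Coordinates with wildly different scales per dimension:
x_obs = np.array([[100.0, 0.001], [300.0, 0.004], [200.0, 0.002]])
x_unit, lo, span = normalize_unit_box(x_obs)
print(x_unit.min(axis=0), x_unit.max(axis=0))  # [0. 0.] [1. 1.]
```

Remember to apply the same `(lo, span)` transform to query grids before prediction, otherwise the model is evaluated far outside its training box.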
Variance predictions are all zero¶
Problem: predict_var() returns zeros everywhere.
Solution:
Verify the model was fitted successfully (check model.ls or model.epsilon)
Ensure query points are not exactly at training points (the variance there collapses to the noise level)
Check that the smoothing parameter is not too large
Length scale optimization fails¶
Problem: Optimizer reports “optimization failed” or length scales are extreme.
Solution:
Provide better initial guesses:
model = GradientMatern(x_obs, y_obs, ls=0.5, smoothing=1e-4)
Check data scaling: inputs should be O(1) in magnitude
Try a different kernel (TPS is more stable for some problems)
Gradient-enhanced model slower than expected¶
Problem: Fitting with gradients is much slower than standard RBF.
Solution:
This is expected: gradient-enhanced models build a \((D+1)N \times (D+1)N\) covariance matrix, so factorization cost grows roughly as \(((D+1)N)^3\)
Use Nystrom approximation for large datasets
Consider using standard (non-gradient) models if gradients are noisy
API Reference¶
For detailed method signatures, see the Surfaces API Reference.