Surface Fitting with Kernel Methods

The rgpycrumbs.surfaces module provides high-performance, differentiable surface models using JAX. These models are designed for constructing Potential Energy Surfaces (PES) and other multidimensional functions, supporting both standard observations and gradient-enhanced data.

Available Models

A variety of kernel-based models are provided via a unified interface:

  • TPS/RBF: Thin Plate Splines, standard for smooth interpolation.

  • Matérn 5/2: A stationary kernel with finite differentiability.

  • IMQ: Inverse Multiquadric kernel, often more stable for large spans.

  • Gradient-Enhanced: Versions of Matérn, Squared Exponential (SE), RQ, and IMQ that incorporate energy gradients directly into the covariance structure.

  • Nystrom Gradient IMQ: A memory-efficient approximation for large datasets.

Uncertainty and Variance

Every model in surfaces.py implements a predict_var(x_query) method. This calculates the posterior predictive variance, providing a measure of uncertainty:

  • Interpretation: The variance represents the model’s confidence. At training points (and within the kernel’s length scale), the variance is low. In data-sparse regions, it reverts toward the prior variance (e.g., \(1/\epsilon\) for IMQ or \(1.0\) for Matérn).

  • Optimization: Parameters like length scales (\(l\), \(\epsilon\)) and noise scalars are optimized via Maximum Likelihood Estimation (MLE) or Maximum A Posteriori (MAP) with physically-informed priors.
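The reversion of the posterior variance toward the prior can be demonstrated with a minimal stand-alone Gaussian-process sketch (plain NumPy with a squared-exponential kernel; this is illustrative only, not the implementation inside surfaces.py):

```python
import numpy as np

def rbf(a, b, ls=1.0):
    # Squared-exponential kernel with unit prior variance.
    d2 = (a[:, None] - b[None, :]) ** 2
    return np.exp(-0.5 * d2 / ls**2)

# Tiny 1-D training set with a small noise term on the diagonal.
x_train = np.array([0.0, 1.0, 2.0])
noise = 1e-6
K_inv = np.linalg.inv(rbf(x_train, x_train) + noise * np.eye(3))

def predict_var(x_query):
    # Posterior variance: k(x, x) - k(x, X) K^{-1} k(X, x).
    k_star = rbf(x_query, x_train)
    return 1.0 - np.sum(k_star @ K_inv * k_star, axis=1)

v_near = predict_var(np.array([1.0]))[0]   # at a training point: ~noise level
v_far = predict_var(np.array([50.0]))[0]   # far from all data: ~prior (1.0)
```

At the training point the variance collapses to roughly the noise floor, while far from the data it returns to the prior variance of 1.0, matching the interpretation above.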

Variance Windowing and Chunking

To prevent Out-Of-Memory (OOM) errors during the evaluation of large grids (e.g., 2D slice visualizations), both the prediction and variance calls are internally chunked.

model = GradientMatern(x_obs, y_obs, gradients=g_obs)
# Default chunk size is 500; here we widen the window to 1000 query points
z_preds = model(x_grid, chunk_size=1000)
z_vars = model.predict_var(x_grid, chunk_size=1000)

This “windowed” evaluation ensures that the large \(N_{query} \times N_{train}\) cross-covariance matrices do not exhaust system memory.
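The windowing idea can be sketched in a few lines of plain NumPy (an illustrative stand-in for the library's internal chunking, not its actual code; `alpha` here plays the role of the precomputed \(K^{-1}y\) weights):

```python
import numpy as np

def rbf(a, b, ls=1.0):
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / ls**2)

def predict_chunked(x_query, x_train, alpha, chunk_size=500):
    # Evaluate the posterior mean window by window so the full
    # N_query x N_train cross-covariance is never materialised at once.
    out = np.empty(len(x_query))
    for start in range(0, len(x_query), chunk_size):
        chunk = x_query[start:start + chunk_size]
        out[start:start + chunk_size] = rbf(chunk, x_train) @ alpha
    return out

x_train = np.linspace(0.0, 1.0, 50)
alpha = np.random.default_rng(0).normal(size=50)  # stand-in for K^{-1} y
x_grid = np.linspace(-1.0, 2.0, 2000)

z_chunked = predict_chunked(x_grid, x_train, alpha, chunk_size=128)
z_full = rbf(x_grid, x_train) @ alpha  # identical result, larger peak memory
```

The chunked and unchunked results agree exactly; only the peak size of the intermediate cross-covariance differs.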

Numerical Stability

Gradient-enhanced models use auto-differentiation (via jax.grad and jax.jacfwd) to construct the full \((D+1)N \times (D+1)N\) covariance matrix.

  • Cholesky Fallback: The safe_cholesky_solve utility attempts Cholesky decomposition with increasing jitter to handle near-singular matrices.

  • Float Precision: By default, models use float32 for performance and compatibility with visualization backends, but jax_enable_x64 can be toggled if higher precision is required.
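The jitter-escalation strategy behind the Cholesky fallback can be sketched as follows (a NumPy illustration of the general technique, not the library's `safe_cholesky_solve` implementation):

```python
import numpy as np

def cholesky_solve_with_jitter(K, y, jitter0=1e-10, max_tries=8):
    # Retry the Cholesky factorisation with geometrically increasing
    # diagonal jitter until the matrix becomes positive definite.
    jitter = jitter0
    for _ in range(max_tries):
        try:
            L = np.linalg.cholesky(K + jitter * np.eye(len(K)))
            # Solve K beta = y via the two triangular systems.
            return np.linalg.solve(L.T, np.linalg.solve(L, y))
        except np.linalg.LinAlgError:
            jitter *= 10.0
    raise np.linalg.LinAlgError(f"matrix not PD even with jitter {jitter:g}")

# Rank-deficient Gram matrix (duplicate row): plain Cholesky fails,
# but the first jitter level already rescues the solve.
X = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
K = X @ X.T
y = np.array([1.0, 1.0, 2.0])
beta = cholesky_solve_with_jitter(K, y)
```

Duplicate training points produce exactly this kind of singular kernel matrix, which is why the fallback matters in practice.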

Troubleshooting

“Out of memory” during surface fitting

Problem: JAX runs out of GPU/CPU memory when fitting large datasets.

Solution:

  1. Use Nystrom approximation for datasets with >1000 points:

    from rgpycrumbs.surfaces import get_surface_model
    model = get_surface_model("grad_imq_ny")(x_obs, y_obs, gradients=g_obs, n_inducing=200)
    
  2. Reduce n_inducing parameter (default: 100)

  3. Use float32 precision instead of float64
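A back-of-envelope estimate shows why the Nystrom route helps; the figures below are simple arithmetic on assumed sizes (2000 points in 3 dimensions, 200 inducing points), not measurements of the library:

```python
# Memory for a gradient-enhanced covariance in float64 (8 bytes/entry).
N, D = 2000, 3          # training points and input dimension
m = 200                 # inducing points for the Nystrom approximation

full_entries = ((D + 1) * N) ** 2          # dense (D+1)N x (D+1)N matrix
ny_entries = (D + 1) * N * (D + 1) * m     # the N x m cross block dominates

full_gb = full_entries * 8 / 1e9
ny_gb = ny_entries * 8 / 1e9
print(f"full: {full_gb:.1f} GB, Nystrom: {ny_gb:.2f} GB")
# → full: 0.5 GB, Nystrom: 0.05 GB
```

The saving grows linearly in \(N/m\), so the larger the dataset relative to the inducing set, the bigger the win.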

Cholesky decomposition fails

Problem: safe_cholesky_solve reports numerical instability.

Solution:

  1. Increase the smoothing parameter (adds jitter to diagonal):

    model = GradientMatern(x_obs, y_obs, gradients=g_obs, smoothing=1e-3)
    
  2. Check for duplicate or near-duplicate training points

  3. Normalize input coordinates to [0, 1] range
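Step 3 can be done with a simple per-column min-max rescale (a generic NumPy helper shown for illustration; it is not part of rgpycrumbs):

```python
import numpy as np

def minmax_normalize(x):
    # Map each input column to [0, 1]; mixed coordinate scales are a
    # common cause of ill-conditioned kernel matrices.
    lo, hi = x.min(axis=0), x.max(axis=0)
    return (x - lo) / np.where(hi > lo, hi - lo, 1.0)

# Columns with wildly different magnitudes, e.g. a distance and an angle.
x_raw = np.array([[100.0, 0.001], [300.0, 0.004], [200.0, 0.002]])
x_norm = minmax_normalize(x_raw)
```

Remember to apply the same affine transform to any query points before prediction.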

Variance predictions are all zero

Problem: predict_var() returns zeros everywhere.

Solution:

  1. Verify the model was fitted successfully (check model.ls or model.epsilon)

  2. Ensure query points do not coincide exactly with training points (the variance there collapses to the noise level, which can be near zero)

  3. Check that smoothing parameter is not too large

Length scale optimization fails

Problem: Optimizer reports “optimization failed” or length scales are extreme.

Solution:

  1. Provide better initial guesses:

    model = GradientMatern(x_obs, y_obs, ls=0.5, smoothing=1e-4)
    
  2. Check data scaling: inputs should be O(1) in magnitude

  3. Try a different kernel (TPS is more stable for some problems)
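For step 1, the median pairwise distance between training points is a common, data-driven initial guess for a length scale (a standard heuristic sketched here in NumPy; the library does not necessarily use it internally):

```python
import numpy as np

def median_heuristic_ls(x):
    # Median pairwise Euclidean distance over the training set:
    # a robust order-of-magnitude starting point for the length scale.
    diff = x[:, None, :] - x[None, :, :]
    d = np.sqrt((diff ** 2).sum(-1))
    return np.median(d[np.triu_indices(len(x), k=1)])

x_obs = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ls0 = median_heuristic_ls(x_obs)  # pass as the initial ls value
```

A guess of the right order of magnitude is usually enough for MLE/MAP optimization to converge.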

Gradient-enhanced model slower than expected

Problem: Fitting with gradients is much slower than standard RBF.

Solution:

  1. This is expected: gradient-enhanced models build a \((D+1)N \times (D+1)N\) covariance matrix, so fitting cost grows roughly as \(((D+1)N)^3\)

  2. Use Nystrom approximation for large datasets

  3. Consider using standard (non-gradient) models if gradients are noisy

API Reference

For detailed method signatures, see the Surfaces API Reference.