
API Reference

This is the canonical reference for the exported gsax surface. The package has two workflows:

  • Sobol: sample() -> analyze()
  • RS-HDMR: analyze_hdmr() -> emulate_hdmr()


Exported Surface

Top-level exports from gsax:

Problem Definition

Problem

Immutable dataclass defining parameter names, bounds, and optional output names.

```python
@dataclass(frozen=True)
class Problem:
    names: tuple[str, ...]
    bounds: tuple[tuple[float, float], ...]
    output_names: tuple[str, ...] | None = None
```

| Field / Property | Type | Description |
| --- | --- | --- |
| names | tuple[str, ...] | Parameter names in model-input order. |
| bounds | tuple[tuple[float, float], ...] | Inclusive lower and upper bound for each parameter. |
| output_names | tuple[str, ...] \| None | Optional labels for output coordinates in to_dataset(). |
| num_vars | int | Property returning len(names). |

Validation and behavior:

  • Problem is lightweight and does not validate matching lengths or bound ordering on construction.
  • Keep names and bounds aligned manually when instantiating directly.
  • Prefer output_names whenever results will be exported with to_dataset().
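Because Problem does not validate itself, a small pre-flight check can catch misaligned or inverted bounds before sampling. The sketch below is a hypothetical helper, not part of gsax. It uses a local stand-in dataclass with the same fields as Problem so it runs without gsax installed, and it also rejects duplicate names, which would make labeled output ambiguous.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ProblemLike:
    """Local stand-in mirroring gsax.Problem's documented fields (illustration only)."""
    names: tuple[str, ...]
    bounds: tuple[tuple[float, float], ...]
    output_names: tuple[str, ...] | None = None


def check_problem(problem) -> None:
    """Hypothetical helper: raise ValueError on inconsistencies Problem does not check."""
    if len(problem.names) != len(problem.bounds):
        raise ValueError(f"{len(problem.names)} names but {len(problem.bounds)} bounds")
    if len(set(problem.names)) != len(problem.names):
        raise ValueError("parameter names must be unique")
    for name, (lo, hi) in zip(problem.names, problem.bounds):
        if not lo < hi:
            raise ValueError(f"bounds for {name!r} must satisfy lower < upper, got ({lo}, {hi})")


# Passes silently: names and bounds are aligned and ordered
check_problem(ProblemLike(names=("a", "b"), bounds=((0.0, 1.0), (-1.0, 1.0))))
```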

Problem.from_dict()

```python
@classmethod
def from_dict(
    cls,
    params: dict[str, tuple[float, float]],
    output_names: tuple[str, ...] | None = None,
) -> Problem
```

params keys become names, values become bounds, preserving insertion order.
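Conceptually, this splits the dict into two parallel tuples. The plain-Python sketch below mirrors the documented behavior only; problem_fields_from_dict is a hypothetical name and is not gsax's implementation.

```python
from __future__ import annotations


def problem_fields_from_dict(params: dict[str, tuple[float, float]]):
    """Hypothetical illustration of how from_dict splits params (not gsax code)."""
    # Dicts preserve insertion order (Python 3.7+), so names and bounds stay aligned
    names = tuple(params.keys())
    bounds = tuple(tuple(b) for b in params.values())
    return names, bounds


names, bounds = problem_fields_from_dict({"amplitude": (0.5, 2.0), "frequency": (1.0, 5.0)})
print(names)   # ('amplitude', 'frequency')
print(bounds)  # ((0.5, 2.0), (1.0, 5.0))
```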

Minimal example:

```python
import gsax

problem = gsax.Problem.from_dict(
    {
        "amplitude": (0.5, 2.0),
        "frequency": (1.0, 5.0),
        "damping": (0.01, 0.5),
    },
    output_names=("displacement", "velocity"),
)

print(problem.num_vars)  # 3
```


Sobol Workflow

sample()

Generate a Sobol/Saltelli sample matrix of unique rows for model evaluation.

```python
def sample(
    problem: Problem,
    n_samples: int,
    *,
    calc_second_order: bool = True,
    scramble: bool = True,
    seed: int | np.random.Generator | None = None,
    verbose: bool = True,
) -> SamplingResult
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| problem | Problem | required | Parameter space definition. |
| n_samples | int | required | Minimum desired number of unique model evaluations. |
| calc_second_order | bool | True | Include BA blocks so S2 can be computed later. |
| scramble | bool | True | Apply Owen scrambling to the Sobol sequence. |
| seed | int \| np.random.Generator \| None | None | Seed or NumPy generator for reproducibility. |
| verbose | bool | True | Print a compact sampling summary. |

Returns: SamplingResult

Shape and behavior:

  • sample() returns unique rows only, not the expanded Saltelli matrix.
  • The returned sample matrix has shape (n_total, D).
  • n_samples is a minimum target, not an exact promise. Internally, base_n is promoted to the next power of 2 and exact duplicate Saltelli rows are removed.
  • When calc_second_order=False, later Sobol analysis returns S2=None.
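To anticipate how many evaluations a given n_samples will cost, the sketch below estimates base_n and the expanded row count. It assumes the standard Saltelli layout used by SALib (step = 2D + 2 rows per base sample when calc_second_order=True, D + 2 otherwise) and that base_n is the smallest power of 2 with base_n * step >= n_samples. Neither rule is documented gsax behavior, so treat the result as an estimate: gsax may need a larger base_n if deduplication pushes the unique count below n_samples, and n_total can be smaller than the expanded count.

```python
import math


def estimate_saltelli_size(n_samples: int, n_params: int, calc_second_order: bool = True):
    """Estimate (base_n, step, expanded_n_total) under SALib-style Saltelli assumptions."""
    # Assumed rows per base sample: A, B, D AB_i rows, plus D BA_i rows for second order
    step = 2 * n_params + 2 if calc_second_order else n_params + 2
    # Smallest base count that reaches the target, promoted to the next power of 2
    needed = max(1, math.ceil(n_samples / step))
    base_n = 1 << (needed - 1).bit_length()
    return base_n, step, base_n * step


print(estimate_saltelli_size(4096, 3))                           # (512, 8, 4096)
print(estimate_saltelli_size(4096, 3, calc_second_order=False))  # (1024, 5, 5120)
```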

Minimal example:

```python
import gsax
import jax.numpy as jnp
from gsax.benchmarks.ishigami import PROBLEM, evaluate

sampling_result = gsax.sample(PROBLEM, n_samples=4096, seed=42)
Y = evaluate(jnp.asarray(sampling_result.samples))
result = gsax.analyze(sampling_result, Y)
```

SamplingResult

Immutable dataclass returned by sample(). It carries the unique rows plus the metadata needed for analyze() to reconstruct the internal Saltelli layout.

```python
@dataclass(frozen=True)
class SamplingResult:
    samples: np.ndarray
    sample_ids: np.ndarray
    expanded_n_total: int
    expanded_to_unique: np.ndarray
    base_n: int
    n_params: int
    calc_second_order: bool
    problem: Problem
```

| Field | Type | Shape / Value | Description |
| --- | --- | --- | --- |
| samples | np.ndarray | (n_total, D) | Unique rows to evaluate with your model. |
| sample_ids | np.ndarray | (n_total,) | Stable integer row IDs aligned with samples. |
| expanded_n_total | int | base_n * step, where step is the number of Saltelli rows per base sample | Expanded Saltelli row count reconstructed internally by analyze(). |
| expanded_to_unique | np.ndarray | (expanded_n_total,) | Map from expanded Saltelli rows back to rows of samples. |
| base_n | int | power of 2 | Base Sobol sample count. |
| n_params | int | D | Number of parameters. |
| calc_second_order | bool | | Whether BA blocks were included. |
| problem | Problem | | Problem used to generate the samples. |
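The relationship between samples and expanded_to_unique can be illustrated with np.unique: indexing the unique rows with the inverse map reproduces the expanded matrix, and the same index expands per-row outputs to the full Saltelli layout. This toy sketch uses a hand-made matrix; gsax's own deduplication and row ordering may differ (np.unique, for instance, sorts rows).

```python
import numpy as np

# Toy "expanded" matrix with one exact duplicate row (rows 0 and 2)
expanded = np.array([
    [0.1, 0.2],
    [0.5, 0.2],
    [0.1, 0.2],
    [0.1, 0.9],
])

# Unique rows plus a map from every expanded row to its unique row
samples, expanded_to_unique = np.unique(expanded, axis=0, return_inverse=True)
expanded_to_unique = expanded_to_unique.reshape(-1)  # keep it 1-D across NumPy versions

# Evaluate a toy model once per unique row, then expand outputs to the full layout
y_unique = samples.sum(axis=1)
y_expanded = y_unique[expanded_to_unique]

assert np.array_equal(samples[expanded_to_unique], expanded)
print(samples.shape, y_expanded.shape)  # (3, 2) (4,)
```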

SamplingResult.n_total

Property returning samples.shape[0], i.e. the unique-row count.

SamplingResult.samples_df

Property returning a pandas DataFrame with SampleID followed by one column per parameter. Use it for export, inspection, or joining model outputs back to inputs.
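When model runs finish out of order (for example on a cluster), outputs often come back keyed by SampleID, while Y passed to analyze() must follow the row order of samples. A minimal NumPy sketch of realigning them; the results dict is a hypothetical stand-in for whatever your job runner returns. Filling failed runs with NaN lets analyze() drop the affected sample groups under its non-finite handling.

```python
import numpy as np

# Stand-ins for sampling_result.sample_ids and outputs gathered from external runs
sample_ids = np.array([0, 1, 2, 3])
results = {2: 0.7, 0: 0.3, 1: 0.5}  # keyed by SampleID; run 3 failed

# Reorder into the row order of sampling_result.samples;
# missing runs become NaN instead of raising
Y = np.array([results.get(int(i), np.nan) for i in sample_ids], dtype=float)
print(Y.shape)  # (4,)
```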

SamplingResult.save()

```python
sampling_result.save("runs/experiment", format="csv")
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| path | str \| Path | required | File stem with no extension. |
| format | str | "csv" | One of csv, txt, xlsx, parquet, or pkl. |

Behavior and validation:

  • Writes path.<format> with the unique rows only.
  • Writes path.json with the Problem and Saltelli reconstruction metadata.
  • Writes path.npz only when expanded_to_unique is not the identity mapping.
  • Raises ValueError for unsupported formats.
  • xlsx requires openpyxl; parquet requires pyarrow.
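The file layout described above can be predicted ahead of time, for example to check which artifacts to copy between machines. expected_save_files is a hypothetical helper that only restates the documented rules; it does not call gsax.

```python
from pathlib import Path

import numpy as np

SUPPORTED_FORMATS = {"csv", "txt", "xlsx", "parquet", "pkl"}


def expected_save_files(path, fmt, expanded_to_unique):
    """Hypothetical helper listing the files save() is documented to write."""
    if fmt not in SUPPORTED_FORMATS:
        raise ValueError(f"unsupported format: {fmt!r}")
    stem = str(path)
    files = [Path(f"{stem}.{fmt}"), Path(f"{stem}.json")]
    # The .npz map is only written when the expanded-to-unique map is not the identity
    is_identity = np.array_equal(expanded_to_unique, np.arange(len(expanded_to_unique)))
    if not is_identity:
        files.append(Path(f"{stem}.npz"))
    return files


print(expected_save_files("runs/experiment", "csv", np.arange(8)))
```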

load()

Reconstruct a saved SamplingResult.

```python
def load(path: str | Path, *, format: str = "csv") -> SamplingResult
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| path | str \| Path | required | File stem previously passed to save(). |
| format | str | "csv" | Must match the format used when saving. |

Validation and behavior:

  • Rebuilds Problem, base_n, expanded_n_total, and expanded_to_unique.
  • The sample format is not auto-detected; pass the same format explicitly.
  • Raises FileNotFoundError if the metadata JSON is missing.
  • Raises ValueError for unsupported formats.
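Since load() does not auto-detect the sample format, a small lookup over the candidate files can supply it. detect_saved_format is a hypothetical helper; the usage comment assumes load is reachable as gsax.load.

```python
from pathlib import Path

FORMATS = ("csv", "txt", "xlsx", "parquet", "pkl")


def detect_saved_format(path) -> str:
    """Hypothetical helper: find which sample format was saved next to path.json."""
    stem = str(path)
    if not Path(f"{stem}.json").exists():
        raise FileNotFoundError(f"missing metadata file {stem}.json")
    found = [fmt for fmt in FORMATS if Path(f"{stem}.{fmt}").exists()]
    if len(found) != 1:
        raise ValueError(f"expected exactly one sample file for {stem!r}, found {found}")
    return found[0]


# Usage (requires gsax; assumes load is exported as gsax.load):
# sampling_result = gsax.load(stem, format=detect_saved_format(stem))
```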


analyze()

Compute Sobol first-order, total-order, and optional second-order indices from model outputs evaluated on SamplingResult.samples.

```python
def analyze(
    sampling_result: SamplingResult,
    Y: Array,
    *,
    prenormalize: bool = False,
    num_resamples: int = 0,
    conf_level: float = 0.95,
    ci_method: Literal["quantile", "gaussian"] = "quantile",
    key: Array | None = None,
    chunk_size: int = 2048,
) -> SAResult
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| sampling_result | SamplingResult | required | Result from sample(). |
| Y | Array | required | Model outputs on the unique rows in sampling_result.samples. |
| prenormalize | bool | False | Apply SALib-style output standardization over the sample axis before analysis. |
| num_resamples | int | 0 | Number of bootstrap resamples. |
| conf_level | float | 0.95 | Confidence level for bootstrap intervals. |
| ci_method | Literal["quantile", "gaussian"] | "quantile" | Bootstrap CI summary method. quantile returns percentile endpoints; gaussian returns symmetric gaussian endpoints from the bootstrap standard deviation. |
| key | Array \| None | None | JAX PRNG key; required when num_resamples > 0. |
| chunk_size | int | 2048 | Number of (T, K) output combinations processed per batch on the no-bootstrap path. |

Accepted output shapes:

  • (n_total,) for scalar output
  • (n_total, K) for multi-output
  • (n_total, T, K) for time-series multi-output
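A NumPy sketch of arranging model outputs into the accepted layouts (sizes are placeholders). Note the single-output time series: left as (n_total, T), it would be read as (n_total, K) with K = T, so it needs an explicit trailing output axis.

```python
import numpy as np

n_total, T, K = 8, 5, 2  # placeholder sizes
rng = np.random.default_rng(0)

y_scalar = rng.normal(size=n_total)          # (n_total,): one scalar output
y_multi = rng.normal(size=(n_total, K))      # (n_total, K): K outputs
y_series = rng.normal(size=(n_total, T, K))  # (n_total, T, K): K time series

# One time series per run arrives as (n_total, T); add the output axis explicitly
one_series = rng.normal(size=(n_total, T))
y_one_series = one_series[:, :, np.newaxis]  # (n_total, T, 1)
print(y_one_series.shape)  # (8, 5, 1)
```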

Validation and behavior:

  • A 2D array is always interpreted as (N, K), never (N, T).
  • For a time-series with one output, reshape to (N, T, 1).
  • When prenormalize=True, Y is centered and scaled once per output slice over the sample axis after Saltelli reconstruction and non-finite-group cleanup.
  • ci_method accepts "quantile" and "gaussian". The option is ignored when num_resamples == 0 because no CI arrays are produced.
  • If num_resamples > 0, key is required or ValueError is raised.
  • Sample groups containing any non-finite values are dropped before analysis.
  • If every group is invalid, ValueError("All samples contain non-finite values") is raised.
  • Zero-variance slices emit warnings because Sobol indices become undefined.
  • Bootstrap intervals are always returned as lower/upper endpoint arrays, not SALib-style half-widths, for both ci_method values.

Returns: SAResult
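The prenormalization step and the two ci_method options can be illustrated with NumPy. This is a conceptual sketch under stated assumptions, not gsax's code: it standardizes with the population standard deviation (ddof=0), and it centers the gaussian interval on the point estimate using the bootstrap standard deviation with ddof=1. gsax's exact choices may differ.

```python
from statistics import NormalDist

import numpy as np


def prenormalize(Y):
    """Center and scale each output slice over the sample axis (axis 0).

    Zero-variance slices produce NaN/inf here.
    """
    mean = Y.mean(axis=0)
    std = Y.std(axis=0)  # assumption: population std (ddof=0)
    return (Y - mean) / std


def bootstrap_ci(estimate, boot, conf_level=0.95, ci_method="quantile"):
    """Return a (2, ...) array of [lower, upper] endpoints.

    boot holds bootstrap replicates with a leading resample axis: (num_resamples, ...).
    """
    alpha = 1.0 - conf_level
    if ci_method == "quantile":
        # Percentile endpoints of the bootstrap distribution
        return np.quantile(boot, [alpha / 2, 1 - alpha / 2], axis=0)
    if ci_method == "gaussian":
        # Symmetric endpoints: estimate -/+ z * bootstrap std (centering is an assumption)
        z = NormalDist().inv_cdf(1 - alpha / 2)
        half = z * boot.std(axis=0, ddof=1)
        return np.stack([estimate - half, estimate + half])
    raise ValueError(f"unknown ci_method: {ci_method!r}")


rng = np.random.default_rng(0)
Y = rng.normal(3.0, 2.0, size=(1000, 2))
Yn = prenormalize(Y)
print(np.allclose(Yn.mean(axis=0), 0.0), np.allclose(Yn.std(axis=0), 1.0))  # True True

s1 = np.array([0.31, 0.44, 0.0])
boot = s1 + rng.normal(0.0, 0.02, size=(200, 3))
print(bootstrap_ci(s1, boot).shape)  # (2, 3)
```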

SAResult

Dataclass holding Sobol point estimates, optional bootstrap intervals, and diagnostic NaN counts.

```python
@dataclass
class SAResult:
    S1: Array
    ST: Array
    S2: Array | None
    problem: Problem
    S1_conf: Array | None = None
    ST_conf: Array | None = None
    S2_conf: Array | None = None
    nan_counts: dict[str, int] | None = None
```

| Field | Shape | Description |
| --- | --- | --- |
| S1 | (D,) / (K, D) / (T, K, D) | First-order Sobol indices. |
| ST | same as S1 | Total-order Sobol indices. |
| S2 | (D, D) / (K, D, D) / (T, K, D, D) or None | Symmetric second-order matrix with NaN diagonal. |
| S1_conf, ST_conf, S2_conf | (2, ...) or None | Bootstrap lower and upper bounds. |
| problem | Problem | Problem carried through for labeling and metadata. |
| nan_counts | dict[str, int] \| None | Diagnostic NaN counts in the result arrays. |

Shape contract:

| Y shape passed to analyze() | S1 / ST | S2 |
| --- | --- | --- |
| (N,) | (D,) | (D, D) |
| (N, K) | (K, D) | (K, D, D) |
| (N, T, K) | (T, K, D) | (T, K, D, D) |

S2 is None when sampling_result.calc_second_order is False. Confidence interval arrays, when present, prepend a leading dimension of 2 for [lower, upper].
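For shape checks in tests, the shape contract can be encoded in a small function. expected_sobol_shapes is a hypothetical helper, not part of gsax; it only restates the documented rules, including the leading axis of length 2 on confidence arrays.

```python
def expected_sobol_shapes(y_shape, n_params, calc_second_order=True, bootstrap=False):
    """Hypothetical helper encoding the documented analyze() shape contract."""
    if len(y_shape) == 1:    # (N,)
        lead = ()
    elif len(y_shape) == 2:  # (N, K), never (N, T)
        lead = (y_shape[1],)
    elif len(y_shape) == 3:  # (N, T, K)
        lead = (y_shape[1], y_shape[2])
    else:
        raise ValueError(f"unsupported Y shape {y_shape}")
    shapes = {
        "S1": lead + (n_params,),
        "ST": lead + (n_params,),
        "S2": lead + (n_params, n_params) if calc_second_order else None,
    }
    if bootstrap:
        # Confidence arrays prepend a [lower, upper] axis of length 2
        shapes.update({f"{k}_conf": (2,) + v for k, v in shapes.items() if v is not None})
    return shapes


print(expected_sobol_shapes((4096, 50, 2), 3, bootstrap=True)["S2_conf"])  # (2, 50, 2, 3, 3)
```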

SAResult.to_dataset()

```python
ds = result.to_dataset(time_coords=None)
```

Converts Sobol results to a labeled xarray.Dataset.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| time_coords | list \| np.ndarray \| None | None | Coordinate values for the time dimension on 3D results. |

Behavior:

  • Uses problem.names for parameter coordinates.
  • Uses problem.output_names when available, otherwise y0, y1, and so on.
  • Splits confidence intervals into *_lower and *_upper dataset variables.
  • Uses param_i and param_j dimensions for S2.
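The labeling rules can be mimicked without xarray, which may help when post-processing raw arrays. Both helpers are hypothetical and do not build an xarray.Dataset. The S1_lower/S1_upper naming assumes the * in *_lower is the index name, which is an assumption rather than documented behavior.

```python
import numpy as np


def output_labels(output_names, n_outputs):
    """problem.output_names when available, otherwise y0, y1, ..."""
    if output_names is not None:
        return list(output_names)
    return [f"y{i}" for i in range(n_outputs)]


def split_conf(name, conf):
    """Split a (2, ...) [lower, upper] array into two named entries."""
    return {f"{name}_lower": conf[0], f"{name}_upper": conf[1]}


conf = np.array([[0.28, 0.40], [0.34, 0.48]])  # (2, D) with D = 2
print(sorted(split_conf("S1", conf)))  # ['S1_lower', 'S1_upper']
print(output_labels(None, 3))          # ['y0', 'y1', 'y2']
```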

Minimal example:

```python
import jax
import jax.numpy as jnp
import gsax
from gsax.benchmarks.ishigami import PROBLEM, evaluate

sampling_result = gsax.sample(PROBLEM, n_samples=4096, seed=42)
Y = evaluate(jnp.asarray(sampling_result.samples))
result = gsax.analyze(
    sampling_result,
    Y,
    prenormalize=True,
    num_resamples=200,
    key=jax.random.key(0),
)

print(result.S1)
print(result.ST)
print(result.S2 is not None)  # True: calc_second_order defaults to True
print(result.nan_counts)

# Labeled export, including *_lower / *_upper confidence variables
ds = result.to_dataset()
print(ds)
```


RS-HDMR Workflow

analyze_hdmr()

Fit an RS-HDMR surrogate on arbitrary (X, Y) pairs and derive ANCOVA-based sensitivity indices.

```python
def analyze_hdmr(
    problem: Problem,
    X: Array,
    Y: Array,
    *,
    prenormalize: bool = False,
    maxorder: int = 2,
    maxiter: int = 100,
    m: int = 2,
    lambdax: float = 0.01,
    chunk_size: int = 2048,
) -> HDMRResult
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| problem | Problem | required | Bounds and names used to normalize X. |
| X | Array | required | Input array with shape (N, D). |
| Y | Array | required | Output array with shape (N,), (N, K), or (N, T, K). |
| prenormalize | bool | False | Apply SALib-style output standardization over the sample axis before fitting. |
| maxorder | int | 2 | Maximum HDMR expansion order. |
| maxiter | int | 100 | Maximum backfitting iterations. |
| m | int | 2 | Number of B-spline intervals. |
| lambdax | float | 0.01 | Tikhonov regularization strength. |
| chunk_size | int | 2048 | Maximum (T, K) combinations per batch. |

Validation and behavior:

  • X.shape[1] must match problem.num_vars.
  • At least 300 rows are required or ValueError is raised.
  • maxorder must be 1, 2, or 3.
  • When D == 2, maxorder cannot exceed 2.
  • chunk_size must be at least 1.
  • A 2D output array is always treated as (N, K).
  • When prenormalize=True, Y is centered and scaled once per output slice over the sample axis before surrogate fitting.

Returns: HDMRResult
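Running the documented input rules before a long fit can save a failed job. check_hdmr_inputs restates those rules and is hypothetical, not part of gsax. to_unit_cube assumes that "normalize X" means a linear map of each column from its problem bounds to [0, 1]; that mapping is an assumption.

```python
import numpy as np


def check_hdmr_inputs(X, n_params, maxorder=2, chunk_size=2048, min_rows=300):
    """Hypothetical pre-flight check mirroring analyze_hdmr()'s documented rules."""
    X = np.asarray(X)
    if X.ndim != 2 or X.shape[1] != n_params:
        raise ValueError(f"X must have shape (N, {n_params}), got {X.shape}")
    if X.shape[0] < min_rows:
        raise ValueError(f"need at least {min_rows} rows, got {X.shape[0]}")
    if maxorder not in (1, 2, 3):
        raise ValueError("maxorder must be 1, 2, or 3")
    if n_params == 2 and maxorder > 2:
        raise ValueError("maxorder cannot exceed 2 when D == 2")
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")


def to_unit_cube(X, bounds):
    """Assumed linear map of each column from [lower, upper] to [0, 1]."""
    b = np.asarray(bounds, dtype=float)
    return (np.asarray(X, dtype=float) - b[:, 0]) / (b[:, 1] - b[:, 0])


check_hdmr_inputs(np.zeros((300, 3)), n_params=3)  # passes
```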

emulate_hdmr()

Predict at new input points using the surrogate stored in an HDMRResult.

```python
def emulate_hdmr(result: HDMRResult, X_new: Array) -> Array
```

| Parameter | Type | Description |
| --- | --- | --- |
| result | HDMRResult | Fitted result; result.emulator must not be None. |
| X_new | Array | New input points with shape (N_new, D). |

Validation and behavior:

  • Raises ValueError when result.emulator is None.
  • Returns (N_new,), (N_new, K), or (N_new, T, K) to match the fitted output layout.
  • When the result was fit with prenormalize=True, predictions are mapped back to the original output scale before being returned.
  • Not JIT-compatible because HDMRResult is not a JAX pytree.
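When a fit used prenormalize=True, mapping predictions back is the inverse of standardization. A minimal NumPy sketch, assuming y_mean and y_std (the HDMREmulator keys) hold one value per output slice, that is, the output shape without the sample axis, so they broadcast against predictions. denormalize is a hypothetical name.

```python
import numpy as np


def denormalize(y_pred_normalized, y_mean, y_std):
    """Map standardized predictions back to the original output scale."""
    # Inverse of (Y - mean) / std, applied per output slice via broadcasting
    return y_pred_normalized * y_std + y_mean


rng = np.random.default_rng(1)
Y = rng.normal(5.0, 3.0, size=(50, 4, 2))  # (N, T, K) toy outputs
y_mean, y_std = Y.mean(axis=0), Y.std(axis=0)  # shapes (T, K)
Y_back = denormalize((Y - y_mean) / y_std, y_mean, y_std)
print(np.allclose(Y_back, Y))  # True
```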

HDMRResult

Dataclass holding ANCOVA-decomposed HDMR sensitivities and optional emulator artifacts.

```python
@dataclass
class HDMRResult:
    Sa: Array
    Sb: Array
    S: Array
    ST: Array
    problem: Problem
    terms: tuple[str, ...]
    emulator: HDMREmulator | None = None
    select: Array | None = None
    rmse: Array | None = None
```

| Field | Shape | Description |
| --- | --- | --- |
| Sa | (n_terms,) / (K, n_terms) / (T, K, n_terms) | Structural contribution per term. |
| Sb | same as Sa | Correlative contribution per term. |
| S | same as Sa | Total contribution per term: Sa + Sb. |
| ST | (D,) / (K, D) / (T, K, D) | Total contribution per parameter. |
| terms | tuple[str, ...] | Human-readable term labels such as "x1/x2". |
| emulator | HDMREmulator \| None | Surrogate coefficients and static metadata. |
| select | (n_terms,) or None | F-test selection counts summed across outputs. |
| rmse | () / (K,) / (T, K) or None | Emulator RMSE without the sample axis. |

HDMRResult.S1

Property returning the first-order structural contribution extracted from the first D HDMR terms:

```python
hdmr.S1  # shape matches hdmr.ST
```

This is the Sobol-compatible first-order view of an HDMR fit.
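Since the first D terms are the first-order terms, the S1 view amounts to slicing the last axis of Sa. The sketch below shows that slice plus one plausible way per-parameter totals relate to term contributions: summing S over every term that involves the parameter. That aggregation rule, and labels that join problem.names with "/", are assumptions rather than documented gsax behavior; compare against hdmr.ST before relying on them. All values are toy numbers.

```python
import numpy as np


def first_order_view(Sa, n_params):
    """S1 as documented: first-order structural part from the first D terms."""
    return Sa[..., :n_params]


def total_by_parameter(S, terms, names):
    """Assumed rule: sum S over all terms whose '/'-joined label contains the parameter."""
    members = [set(t.split("/")) for t in terms]
    return np.stack(
        [sum(S[..., j] for j, m in enumerate(members) if name in m) for name in names],
        axis=-1,
    )


names = ("x1", "x2", "x3")
terms = ("x1", "x2", "x3", "x1/x2", "x1/x3", "x2/x3")
Sa = np.array([0.3, 0.4, 0.0, 0.0, 0.2, 0.0])  # toy structural parts
S = np.array([0.3, 0.4, 0.0, 0.0, 0.25, 0.0])  # toy Sa + Sb
print(first_order_view(Sa, len(names)))      # first three entries of Sa
print(total_by_parameter(S, terms, names))   # approximately [0.55, 0.4, 0.25]
```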

HDMRResult.to_dataset()

```python
ds = hdmr.to_dataset(time_coords=None)
```

Converts HDMR results to a labeled xarray.Dataset.

Behavior:

  • Uses a term dimension for Sa, Sb, S, and select.
  • Uses a param dimension for ST.
  • Uses problem.output_names when available, otherwise generated labels.
  • Uses time_coords, when passed, as the time coordinate for 3D results.

HDMREmulator

Typed dictionary stored on HDMRResult.emulator.

```python
class HDMREmulator(TypedDict):
    C1: Array
    C2: Array | None
    C3: Array | None
    f0: Array
    prenormalize: bool
    y_mean: Array
    y_std: Array
    m: int
    maxorder: int
    c2: list[tuple[int, int]]
    c3: list[tuple[int, int, int]]
```

| Key | Description |
| --- | --- |
| C1, C2, C3 | Fitted B-spline coefficients for first-, second-, and third-order terms. |
| f0 | Intercept term in the emulator. |
| prenormalize | Whether the HDMR fit standardized outputs before fitting. |
| y_mean, y_std | Per-output-slice statistics used to map prenormalized predictions back to the original scale. |
| m | Number of spline intervals used during fitting. |
| maxorder | Expansion order used to build the surrogate. |
| c2, c3 | Term-index mappings for pairwise and triple interaction terms. |
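The term bookkeeping behind c2, c3, and terms can be sketched with itertools.combinations: D first-order terms, then C(D, 2) pairs and C(D, 3) triples up to maxorder. hdmr_term_layout is hypothetical, and the lexicographic ordering and "/"-joined labels are assumptions about gsax. With D == 2, there are no three-parameter combinations, which is consistent with the documented rule that maxorder cannot exceed 2 in that case.

```python
from itertools import combinations


def hdmr_term_layout(names, maxorder):
    """Hypothetical sketch of HDMR term bookkeeping (ordering is an assumption)."""
    d = len(names)
    c2 = list(combinations(range(d), 2)) if maxorder >= 2 else []
    c3 = list(combinations(range(d), 3)) if maxorder >= 3 else []
    terms = list(names)
    terms += ["/".join(names[i] for i in pair) for pair in c2]
    terms += ["/".join(names[i] for i in triple) for triple in c3]
    return c2, c3, terms


c2, c3, terms = hdmr_term_layout(("x1", "x2", "x3"), maxorder=2)
print(c2)          # [(0, 1), (0, 2), (1, 2)]
print(len(terms))  # 6
```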

Minimal example:

```python
import jax
import jax.numpy as jnp
import gsax
from gsax.benchmarks.ishigami import PROBLEM, evaluate

key = jax.random.PRNGKey(42)
bounds = jnp.array(PROBLEM.bounds)
X = jax.random.uniform(key, (2000, PROBLEM.num_vars), minval=bounds[:, 0], maxval=bounds[:, 1])
Y = evaluate(X)

hdmr = gsax.analyze_hdmr(PROBLEM, X, Y, maxorder=2)
Y_pred = gsax.emulate_hdmr(hdmr, X[:5])

print(hdmr.S1)
print(hdmr.ST)
print(Y_pred.shape)
```


Released under the MIT License.