API Reference
This is the canonical reference for the exported gsax surface. The package has two workflows:
- Sobol: sample() -> analyze()
- RS-HDMR: analyze_hdmr() -> emulate_hdmr()
Exported Surface
Top-level exports from gsax:
Problem Definition
Problem
Immutable dataclass defining parameter names, bounds, and optional output names.
@dataclass(frozen=True)
class Problem:
    names: tuple[str, ...]
    bounds: tuple[tuple[float, float], ...]
    output_names: tuple[str, ...] | None = None

| Field / Property | Type | Description |
|---|---|---|
| names | tuple[str, ...] | Parameter names in model-input order. |
| bounds | tuple[tuple[float, float], ...] | Inclusive lower and upper bound for each parameter. |
| output_names | tuple[str, ...] \| None | Optional labels for output coordinates in to_dataset(). |
| num_vars | int | Property returning len(names). |
Validation and behavior:
- Problem is lightweight and does not validate matching lengths or bound ordering on construction.
- Keep names and bounds aligned manually when instantiating directly.
- Prefer output_names whenever results will be exported with to_dataset().
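Because Problem does not validate its fields, a small pre-flight check can catch misaligned inputs before any sampling happens. The sketch below is a hypothetical helper, not part of gsax; it works on plain tuples shaped like names and bounds, and treating lower >= upper as an error is an assumption about what counts as a valid bound.

```python
def check_problem_fields(names, bounds):
    """Hypothetical pre-flight check; gsax's Problem does not run this itself."""
    if len(names) != len(bounds):
        raise ValueError(
            f"{len(names)} names but {len(bounds)} bounds; they must align"
        )
    for name, (lower, upper) in zip(names, bounds):
        # Assumption: an inverted or zero-width interval is treated as an error.
        if not lower < upper:
            raise ValueError(
                f"{name}: lower bound {lower} must be below upper bound {upper}"
            )


names = ("amplitude", "frequency", "damping")
bounds = ((0.5, 2.0), (1.0, 5.0), (0.01, 0.5))
check_problem_fields(names, bounds)  # returns None when everything lines up
```

Running a check like this before constructing Problem directly keeps the dataclass itself lightweight while still failing early on obvious mistakes.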
Problem.from_dict()
@classmethod
def from_dict(
    cls,
    params: dict[str, tuple[float, float]],
    output_names: tuple[str, ...] | None = None,
) -> Problem

The keys of params become names and the values become bounds, preserving insertion order.
Minimal example:
import gsax
problem = gsax.Problem.from_dict(
    {
        "amplitude": (0.5, 2.0),
        "frequency": (1.0, 5.0),
        "damping": (0.01, 0.5),
    },
    output_names=("displacement", "velocity"),
)
print(problem.num_vars)  # 3
Sobol Workflow
sample()
Generate a unique Sobol/Saltelli sample matrix for model evaluation.
def sample(
    problem: Problem,
    n_samples: int,
    *,
    calc_second_order: bool = True,
    scramble: bool = True,
    seed: int | np.random.Generator | None = None,
    verbose: bool = True,
) -> SamplingResult

| Parameter | Type | Default | Description |
|---|---|---|---|
| problem | Problem | required | Parameter space definition. |
| n_samples | int | required | Minimum desired number of unique model evaluations. |
| calc_second_order | bool | True | Include BA blocks so S2 can be computed later. |
| scramble | bool | True | Apply Owen scrambling to the Sobol sequence. |
| seed | int \| np.random.Generator \| None | None | Seed or NumPy generator for reproducibility. |
| verbose | bool | True | Print a compact sampling summary. |
Returns: SamplingResult
Shape and behavior:
- sample() returns unique rows only, not the expanded Saltelli matrix.
- The returned sample matrix has shape (n_total, D).
- n_samples is a minimum target, not an exact promise. Internally, base_n is promoted to the next power of 2 and exact duplicate Saltelli rows are removed.
- When calc_second_order=False, later Sobol analysis returns S2=None.
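The power-of-2 promotion and the size of the expanded Saltelli matrix can be sketched with plain integer arithmetic. Both functions below are illustrative helpers, not gsax API. The block count (2D + 2 with second-order terms, D + 2 without) is the usual Saltelli layout and is assumed here; the sketch also does not show how gsax derives base_n from n_samples.

```python
def next_power_of_two(n: int) -> int:
    """Smallest power of 2 that is greater than or equal to n."""
    if n < 1:
        raise ValueError("n must be at least 1")
    return 1 << (n - 1).bit_length()


def expanded_saltelli_rows(base_n: int, n_params: int, calc_second_order: bool) -> int:
    """Row count of the expanded Saltelli matrix, before duplicates are dropped.

    Assumption: the standard Saltelli block layout (2D + 2 or D + 2 blocks).
    """
    step = 2 * n_params + 2 if calc_second_order else n_params + 2
    return base_n * step


print(next_power_of_two(500))                 # 512
print(expanded_saltelli_rows(512, 3, True))   # 4096
print(expanded_saltelli_rows(512, 3, False))  # 2560
```

The unique-row count returned by sample() can be smaller than the expanded count because exact duplicate rows are removed.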
Minimal example:
import gsax
import jax.numpy as jnp
from gsax.benchmarks.ishigami import PROBLEM, evaluate
sampling_result = gsax.sample(PROBLEM, n_samples=4096, seed=42)
Y = evaluate(jnp.asarray(sampling_result.samples))
result = gsax.analyze(sampling_result, Y)

SamplingResult
Immutable dataclass returned by sample(). It carries the unique rows plus the metadata needed for analyze() to reconstruct the internal Saltelli layout.
@dataclass(frozen=True)
class SamplingResult:
    samples: np.ndarray
    sample_ids: np.ndarray
    expanded_n_total: int
    expanded_to_unique: np.ndarray
    base_n: int
    n_params: int
    calc_second_order: bool
    problem: Problem

| Field | Type | Shape / Value | Description |
|---|---|---|---|
| samples | np.ndarray | (n_total, D) | Unique rows to evaluate with your model. |
| sample_ids | np.ndarray | (n_total,) | Stable integer row IDs aligned with samples. |
| expanded_n_total | int | N * step | Expanded Saltelli row count reconstructed internally by analyze(). |
| expanded_to_unique | np.ndarray | (expanded_n_total,) | Map from expanded Saltelli rows back to samples. |
| base_n | int | power of 2 | Base Sobol sample count. |
| n_params | int | D | Number of parameters. |
| calc_second_order | bool | | Whether BA blocks were included. |
| problem | Problem | | Problem used to generate the samples. |
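The relationship between the unique rows and expanded_to_unique can be illustrated with a toy deduplication. This is a sketch of the idea only, using plain Python lists; the function name and the first-occurrence ordering of unique rows are assumptions, not a description of gsax internals.

```python
def deduplicate_rows(expanded_rows):
    """Return (unique_rows, expanded_to_unique) for a list of row tuples."""
    unique_rows = []
    index_of = {}
    expanded_to_unique = []
    for row in expanded_rows:
        if row not in index_of:
            # First time this row is seen: give it the next unique index.
            index_of[row] = len(unique_rows)
            unique_rows.append(row)
        expanded_to_unique.append(index_of[row])
    return unique_rows, expanded_to_unique


expanded = [(0.1, 0.2), (0.3, 0.4), (0.1, 0.2), (0.5, 0.6), (0.3, 0.4)]
unique_rows, expanded_to_unique = deduplicate_rows(expanded)

# Evaluate a stand-in model on the unique rows only.
y_unique = [a + b for a, b in unique_rows]

# Rebuild outputs in the expanded layout by indexing through the map.
y_expanded = [y_unique[i] for i in expanded_to_unique]

print(len(unique_rows), expanded_to_unique)  # 3 [0, 1, 0, 2, 1]
```

The model runs once per unique row, and the map is enough to recover outputs for every expanded row afterwards.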
SamplingResult.n_total
Property returning samples.shape[0], i.e. the unique-row count.
SamplingResult.samples_df
Property returning a pandas DataFrame with SampleID followed by one column per parameter. Use it for export, inspection, or joining model outputs back to inputs.
SamplingResult.save()
sampling_result.save("runs/experiment", format="csv")

| Parameter | Type | Default | Description |
|---|---|---|---|
| path | str \| Path | required | File stem with no extension. |
| format | str | "csv" | One of csv, txt, xlsx, parquet, or pkl. |
Behavior and validation:
- Writes path.<format> with the unique rows only.
- Writes path.json with the Problem and Saltelli reconstruction metadata.
- Writes path.npz only when expanded_to_unique is not the identity mapping.
- Raises ValueError for unsupported formats.
- xlsx requires openpyxl; parquet requires pyarrow.
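The file-naming rules above can be summarized in a few lines. The helper below is hypothetical and only predicts which files a save() call would produce according to those rules; it does not write anything and is not part of gsax.

```python
from pathlib import Path

SUPPORTED_FORMATS = ("csv", "txt", "xlsx", "parquet", "pkl")


def expected_save_files(stem, fmt, expanded_to_unique):
    """Hypothetical helper: list the files implied by the rules above."""
    if fmt not in SUPPORTED_FORMATS:
        raise ValueError(f"unsupported format: {fmt!r}")
    stem = str(stem)
    files = [Path(f"{stem}.{fmt}"), Path(f"{stem}.json")]
    # The .npz sidecar is only needed when the map differs from 0, 1, 2, ...
    is_identity = list(expanded_to_unique) == list(range(len(expanded_to_unique)))
    if not is_identity:
        files.append(Path(f"{stem}.npz"))
    return files


files = expected_save_files("runs/experiment", "csv", [0, 1, 0, 2])
print([p.name for p in files])
# ['experiment.csv', 'experiment.json', 'experiment.npz']
```

When copying or archiving a saved run, all files sharing the stem belong together, since load() needs the JSON metadata in addition to the sample file.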
load()
Reconstruct a saved SamplingResult.
def load(path: str | Path, *, format: str = "csv") -> SamplingResult

| Parameter | Type | Default | Description |
|---|---|---|---|
| path | str \| Path | required | File stem previously passed to save(). |
| format | str | "csv" | Must match the format used when saving. |
Validation and behavior:
- Rebuilds Problem, base_n, expanded_n_total, and expanded_to_unique.
- The sample format is not auto-detected; pass the same format explicitly.
- Raises FileNotFoundError if the metadata JSON is missing.
- Raises ValueError for unsupported formats.
analyze()
Compute Sobol first-order, total-order, and optional second-order indices from model outputs evaluated on SamplingResult.samples.
def analyze(
    sampling_result: SamplingResult,
    Y: Array,
    *,
    prenormalize: bool = False,
    num_resamples: int = 0,
    conf_level: float = 0.95,
    ci_method: Literal["quantile", "gaussian"] = "quantile",
    key: Array | None = None,
    chunk_size: int = 2048,
) -> SAResult

| Parameter | Type | Default | Description |
|---|---|---|---|
| sampling_result | SamplingResult | required | Result from sample(). |
| Y | Array | required | Model outputs on the unique rows in sampling_result.samples. |
| prenormalize | bool | False | Apply SALib-style output standardization over the sample axis before analysis. |
| num_resamples | int | 0 | Number of bootstrap resamples. |
| conf_level | float | 0.95 | Confidence level for bootstrap intervals. |
| ci_method | Literal["quantile", "gaussian"] | "quantile" | Bootstrap CI summary method. quantile returns percentile endpoints; gaussian returns symmetric gaussian endpoints from bootstrap standard deviation. |
| key | Array \| None | None | Required JAX PRNG key when num_resamples > 0. |
| chunk_size | int | 2048 | (T, K) output combinations per batch on the no-bootstrap path. |
Accepted output shapes:
- (n_total,) for scalar output
- (n_total, K) for multi-output
- (n_total, T, K) for time-series multi-output
Validation and behavior:
- A 2D array is always interpreted as (N, K), never (N, T).
- For a time-series with one output, reshape to (N, T, 1).
- When prenormalize=True, Y is centered and scaled once per output slice over the sample axis after Saltelli reconstruction and non-finite-group cleanup.
- ci_method accepts "quantile" and "gaussian". The option is ignored when num_resamples == 0 because no CI arrays are produced.
- If num_resamples > 0, key is required or ValueError is raised.
- Sample groups containing any non-finite values are dropped before analysis.
- If every group is invalid, ValueError("All samples contain non-finite values") is raised.
- Zero-variance slices emit warnings because Sobol indices become undefined.
- Bootstrap intervals always remain lower/upper endpoint arrays, not SALib-style half-widths.
- ci_method="quantile" uses percentile endpoints, while ci_method="gaussian" uses symmetric gaussian endpoints from bootstrap standard deviation.
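The difference between the two ci_method options can be sketched on a single scalar index. The code below is a minimal illustration in plain Python, not the gsax implementation: centering the gaussian interval on the point estimate, using the sample standard deviation, and interpolating percentiles linearly are all assumptions made for the example.

```python
from statistics import NormalDist, stdev


def percentile(sorted_vals, q):
    """Linearly interpolated percentile of pre-sorted values, q in [0, 1]."""
    pos = q * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] * (1 - frac) + sorted_vals[hi] * frac


def bootstrap_interval(point, replicates, conf_level=0.95, method="quantile"):
    """Return (lower, upper) endpoints from bootstrap replicates of one index."""
    alpha = 1.0 - conf_level
    if method == "quantile":
        vals = sorted(replicates)
        return percentile(vals, alpha / 2), percentile(vals, 1 - alpha / 2)
    if method == "gaussian":
        # Assumption: symmetric interval centered on the point estimate.
        z = NormalDist().inv_cdf(1 - alpha / 2)
        half_width = z * stdev(replicates)
        return point - half_width, point + half_width
    raise ValueError(f"unknown method: {method!r}")


replicates = [0.28, 0.31, 0.30, 0.35, 0.29, 0.33, 0.32, 0.27, 0.34, 0.30]
print(bootstrap_interval(0.31, replicates, method="quantile"))
print(bootstrap_interval(0.31, replicates, method="gaussian"))
```

Both methods return endpoints rather than half-widths, which matches the lower/upper convention described above. The quantile interval can be asymmetric around the point estimate, while the gaussian one is symmetric by construction.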
Returns: SAResult
SAResult
Dataclass holding Sobol point estimates, optional bootstrap intervals, and diagnostic NaN counts.
@dataclass
class SAResult:
    S1: Array
    ST: Array
    S2: Array | None
    problem: Problem
    S1_conf: Array | None = None
    ST_conf: Array | None = None
    S2_conf: Array | None = None
    nan_counts: dict[str, int] | None = None

| Field | Shape | Description |
|---|---|---|
| S1 | (D,) / (K, D) / (T, K, D) | First-order Sobol indices. |
| ST | same as S1 | Total-order Sobol indices. |
| S2 | (D, D) / (K, D, D) / (T, K, D, D) or None | Symmetric second-order matrix with NaN diagonal. |
| S1_conf, ST_conf, S2_conf | (2, ...) or None | Bootstrap lower and upper bounds. |
| problem | Problem | Problem carried through for labeling and metadata. |
| nan_counts | dict[str, int] \| None | Diagnostic NaN counts in the result arrays. |
Shape contract:
| Y shape passed to analyze() | S1 / ST | S2 |
|---|---|---|
| (N,) | (D,) | (D, D) |
| (N, K) | (K, D) | (K, D, D) |
| (N, T, K) | (T, K, D) | (T, K, D, D) |
S2 is None when sampling_result.calc_second_order is False. Confidence interval arrays, when present, prepend a leading dimension of 2 for [lower, upper].
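The shape contract can be expressed as a small function that maps a Y shape to the expected result shapes. This is an illustrative helper with hypothetical names, handy for writing assertions in your own tests; it does not call gsax.

```python
def sobol_result_shapes(y_shape, n_params, calc_second_order=True):
    """Map a Y shape to the (S1/ST shape, S2 shape) pair from the shape contract."""
    if len(y_shape) not in (1, 2, 3):
        raise ValueError("Y must have shape (N,), (N, K), or (N, T, K)")
    leading = tuple(y_shape[1:])  # (), (K,), or (T, K)
    s1_shape = leading + (n_params,)
    s2_shape = leading + (n_params, n_params) if calc_second_order else None
    return s1_shape, s2_shape


def conf_shape(index_shape):
    """Confidence arrays prepend a dimension of 2 for [lower, upper]."""
    return None if index_shape is None else (2,) + tuple(index_shape)


s1_shape, s2_shape = sobol_result_shapes((4096, 50, 2), n_params=3)
print(s1_shape, s2_shape)    # (50, 2, 3) (50, 2, 3, 3)
print(conf_shape(s1_shape))  # (2, 50, 2, 3)
```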
SAResult.to_dataset()
ds = result.to_dataset(time_coords=None)

Converts Sobol results to a labeled xarray.Dataset.

| Parameter | Type | Default | Description |
|---|---|---|---|
| time_coords | list \| np.ndarray \| None | None | Coordinate values for the time dimension on 3D results. |
Behavior:
- Uses problem.names for parameter coordinates.
- Uses problem.output_names when available, otherwise y0, y1, and so on.
- Splits confidence intervals into *_lower and *_upper dataset variables.
- Uses param_i and param_j dimensions for S2.
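The labeling rules above can be mimicked in plain Python to see what the dataset coordinates and variable names would look like. Both helpers are hypothetical stand-ins, and the length check on output_names is an assumption rather than documented gsax behavior.

```python
def output_labels(output_names, n_outputs):
    """Use the given labels when present, otherwise fall back to y0, y1, ..."""
    if output_names is not None:
        # Assumption: one label is expected per output column.
        if len(output_names) != n_outputs:
            raise ValueError("output_names must have one label per output")
        return tuple(output_names)
    return tuple(f"y{i}" for i in range(n_outputs))


def split_conf(name, conf):
    """Split a [lower, upper] pair into *_lower and *_upper entries."""
    lower, upper = conf
    return {f"{name}_lower": lower, f"{name}_upper": upper}


print(output_labels(None, 3))  # ('y0', 'y1', 'y2')
print(output_labels(("displacement", "velocity"), 2))
print(split_conf("S1", ([0.28, 0.40], [0.34, 0.48])))
```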
Minimal example:
import jax
import gsax
from gsax.benchmarks.ishigami import PROBLEM, evaluate
sampling_result = gsax.sample(PROBLEM, n_samples=4096, seed=42)
Y = evaluate(sampling_result.samples)
result = gsax.analyze(
    sampling_result,
    Y,
    prenormalize=True,
    num_resamples=200,
    key=jax.random.key(0),
)
print(result.S1)
print(result.ST)
print(result.S2 is not None)
print(result.nan_counts)
RS-HDMR Workflow
analyze_hdmr()
Fit an RS-HDMR surrogate on arbitrary (X, Y) pairs and derive ANCOVA-based sensitivity indices.
def analyze_hdmr(
    problem: Problem,
    X: Array,
    Y: Array,
    *,
    prenormalize: bool = False,
    maxorder: int = 2,
    maxiter: int = 100,
    m: int = 2,
    lambdax: float = 0.01,
    chunk_size: int = 2048,
) -> HDMRResult

| Parameter | Type | Default | Description |
|---|---|---|---|
| problem | Problem | required | Bounds and names used to normalize X. |
| X | Array | required | Input array with shape (N, D). |
| Y | Array | required | Output array with shape (N,), (N, K), or (N, T, K). |
| prenormalize | bool | False | Apply SALib-style output standardization over the sample axis before fitting. |
| maxorder | int | 2 | Maximum HDMR expansion order. |
| maxiter | int | 100 | Maximum backfitting iterations. |
| m | int | 2 | Number of B-spline intervals. |
| lambdax | float | 0.01 | Tikhonov regularization strength. |
| chunk_size | int | 2048 | Maximum (T, K) combinations per batch. |
Validation and behavior:
- X.shape[1] must match problem.num_vars.
- At least 300 rows are required or ValueError is raised.
- maxorder must be 1, 2, or 3.
- When D == 2, maxorder cannot exceed 2.
- chunk_size must be at least 1.
- A 2D output array is always treated as (N, K).
- When prenormalize=True, Y is centered and scaled once per output slice over the sample axis before surrogate fitting.
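The input rules above can be checked up front, before an expensive fit is attempted. The helper below is a hypothetical pre-check written against shapes only. Raising ValueError for every rule is an assumption, since the reference only names that exception for the 300-row minimum.

```python
def check_hdmr_inputs(x_shape, num_vars, maxorder=2, chunk_size=2048, min_rows=300):
    """Hypothetical pre-check mirroring the documented analyze_hdmr() rules."""
    if len(x_shape) != 2:
        raise ValueError("X must have shape (N, D)")
    n_rows, n_cols = x_shape
    if n_cols != num_vars:
        raise ValueError(f"X has {n_cols} columns but the problem has {num_vars} parameters")
    if n_rows < min_rows:
        raise ValueError(f"at least {min_rows} rows are required, got {n_rows}")
    if maxorder not in (1, 2, 3):
        raise ValueError("maxorder must be 1, 2, or 3")
    if n_cols == 2 and maxorder > 2:
        raise ValueError("maxorder cannot exceed 2 when D == 2")
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")


check_hdmr_inputs((2000, 3), num_vars=3, maxorder=2)  # passes silently
```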
Returns: HDMRResult
emulate_hdmr()
Predict at new input points using the surrogate stored in an HDMRResult.
def emulate_hdmr(result: HDMRResult, X_new: Array) -> Array

| Parameter | Type | Description |
|---|---|---|
| result | HDMRResult | Must contain emulator. |
| X_new | Array | New input points with shape (N_new, D). |
Validation and behavior:
- Raises ValueError when result.emulator is None.
- Returns (N_new,), (N_new, K), or (N_new, T, K) to match the fitted output layout.
- When the result was fit with prenormalize=True, predictions are mapped back to the original output scale before being returned.
- Not JIT-compatible because HDMRResult is not a JAX pytree.
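The prenormalize round trip can be illustrated on a single output slice: standardize before fitting, then map predictions back with the stored mean and standard deviation. This is a plain-Python sketch with hypothetical function names. Using the population standard deviation is an assumption about what "SALib-style" standardization means here.

```python
from statistics import fmean, pstdev


def standardize(values):
    """Center and scale one output slice over the sample axis."""
    y_mean = fmean(values)
    y_std = pstdev(values)  # assumption: population standard deviation
    if y_std == 0.0:
        raise ValueError("zero-variance output cannot be standardized")
    scaled = [(v - y_mean) / y_std for v in values]
    return scaled, y_mean, y_std


def unstandardize(scaled, y_mean, y_std):
    """Map standardized values back to the original output scale."""
    return [s * y_std + y_mean for s in scaled]


y = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
scaled, y_mean, y_std = standardize(y)
restored = unstandardize(scaled, y_mean, y_std)
print(y_mean, y_std)  # 5.0 2.0
```

The surrogate only ever sees the scaled values, so keeping y_mean and y_std alongside the fitted coefficients is what lets predictions be returned in the original units.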
HDMRResult
Dataclass holding ANCOVA-decomposed HDMR sensitivities and optional emulator artifacts.
@dataclass
class HDMRResult:
    Sa: Array
    Sb: Array
    S: Array
    ST: Array
    problem: Problem
    terms: tuple[str, ...]
    emulator: HDMREmulator | None = None
    select: Array | None = None
    rmse: Array | None = None

| Field | Shape | Description |
|---|---|---|
| Sa | (n_terms,) / (K, n_terms) / (T, K, n_terms) | Structural contribution per term. |
| Sb | same as Sa | Correlative contribution per term. |
| S | same as Sa | Total contribution per term: Sa + Sb. |
| ST | (D,) / (K, D) / (T, K, D) | Total contribution per parameter. |
| terms | tuple[str, ...] | Human-readable term labels such as "x1/x2". |
| emulator | HDMREmulator \| None | Surrogate coefficients and static metadata. |
| select | (n_terms,) or None | F-test selection counts summed across outputs. |
| rmse | () / (K,) / (T, K) or None | Emulator RMSE without the sample axis. |
HDMRResult.S1
Property returning the first-order structural contribution extracted from the first D HDMR terms:
hdmr.S1  # shape matches hdmr.ST

This is the Sobol-compatible first-order view of an HDMR fit.
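The idea of taking the first D terms can be sketched by enumerating term labels for a small problem. The helper below is hypothetical. It assumes terms are ordered by expansion order and then by parameter combination, which is consistent with the first D terms being first-order but is not confirmed for the higher-order terms.

```python
from itertools import combinations


def hdmr_term_labels(names, maxorder=2):
    """Enumerate term labels such as "x1/x2", lowest order first."""
    labels = []
    for order in range(1, maxorder + 1):
        for combo in combinations(names, order):
            labels.append("/".join(combo))
    return tuple(labels)


names = ("x1", "x2", "x3")
terms = hdmr_term_labels(names, maxorder=2)
print(terms)  # ('x1', 'x2', 'x3', 'x1/x2', 'x1/x3', 'x2/x3')

# Slicing the first D entries of a per-term array gives the first-order view.
Sa = (0.31, 0.44, 0.00, 0.00, 0.24, 0.00)  # made-up values for illustration
first_order = Sa[: len(names)]
print(first_order)  # (0.31, 0.44, 0.0)
```

With D parameters and maxorder=2, this enumeration yields D + D(D-1)/2 terms, so a 3-parameter problem has 6.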
HDMRResult.to_dataset()
ds = hdmr.to_dataset(time_coords=None)

Converts HDMR results to a labeled xarray.Dataset.
Behavior:
- Uses term for Sa, Sb, S, and select.
- Uses param for ST.
- Uses problem.output_names when available, otherwise generated labels.
- Uses time_coords when passed for 3D results.
HDMREmulator
Typed dictionary stored on HDMRResult.emulator.
class HDMREmulator(TypedDict):
    C1: Array
    C2: Array | None
    C3: Array | None
    f0: Array
    prenormalize: bool
    y_mean: Array
    y_std: Array
    m: int
    maxorder: int
    c2: list[tuple[int, int]]
    c3: list[tuple[int, int, int]]

| Key | Description |
|---|---|
| C1, C2, C3 | Fitted B-spline coefficients for first-, second-, and third-order terms. |
| f0 | Intercept term in the emulator. |
| prenormalize | Whether the HDMR fit standardized outputs before fitting. |
| y_mean, y_std | Per-output-slice statistics used to map prenormalized predictions back to the original scale. |
| m | Number of spline intervals used during fitting. |
| maxorder | Expansion order used to build the surrogate. |
| c2, c3 | Term-index mappings for pairwise and triple interaction terms. |
Minimal example:
import jax
import jax.numpy as jnp
import gsax
from gsax.benchmarks.ishigami import PROBLEM, evaluate
key = jax.random.PRNGKey(42)
bounds = jnp.array(PROBLEM.bounds)
X = jax.random.uniform(key, (2000, PROBLEM.num_vars), minval=bounds[:, 0], maxval=bounds[:, 1])
Y = evaluate(X)
hdmr = gsax.analyze_hdmr(PROBLEM, X, Y, maxorder=2)
Y_pred = gsax.emulate_hdmr(hdmr, X[:5])
print(hdmr.S1)
print(hdmr.ST)
print(Y_pred.shape)