The main user-facing function for training. Given prepared training data and, optionally, validation data, it builds the model, creates the physics residual sets, runs the training loop with early stopping, and returns a fitted swrc_fit object for prediction and evaluation.
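The early-stopping step can be illustrated on a precomputed validation-loss curve. This is a sketch, not the package's implementation. `early_stop_epoch()` is a hypothetical helper. It assumes that validation loss is checked every 5 epochs and that training stops after `patience` consecutive checks without improvement, which is one reading of "patience in multiples of 5 epochs".

```r
# Hypothetical helper, not part of the package.
# Assumption: validation is checked every `check_every` epochs, and training
# stops after `patience` consecutive checks with no new minimum.
early_stop_epoch <- function(val_loss, patience = 5L, check_every = 5L) {
  best_loss <- Inf
  best_epoch <- NA_integer_
  bad_checks <- 0L
  for (epoch in seq_along(val_loss)) {
    if (epoch %% check_every != 0L) next   # evaluate only on check epochs
    if (val_loss[epoch] < best_loss) {
      best_loss <- val_loss[epoch]         # new minimum: remember it
      best_epoch <- epoch
      bad_checks <- 0L
    } else {
      bad_checks <- bad_checks + 1L        # no improvement at this check
      if (bad_checks >= patience) {
        return(list(best_epoch = best_epoch, stop_epoch = epoch))
      }
    }
  }
  # Ran out of epochs before patience was exhausted
  list(best_epoch = best_epoch, stop_epoch = length(val_loss))
}

# Loss improves for 30 epochs, then plateaus above its minimum.
val <- c(seq(1, 0.2, length.out = 30), rep(0.25, 50))
early_stop_epoch(val)  # best_epoch = 30, stop_epoch = 55
```

Under these assumptions, the default patience of 5 means up to 25 epochs without improvement are tolerated before stopping.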
Usage
fit_swrc(
  train_df,
  x_inputs,
  val_df = NULL,
  hidden = c(128L, 64L),
  dropout = 0.1,
  lr = 0.001,
  epochs = 80L,
  batch_size = 256L,
  patience = 5L,
  K = 64L,
  lambdas = norouzi_lambdas("norouzi"),
  S1 = 1500L,
  S2 = 500L,
  S3 = 500L,
  S4 = 1500L,
  pF_lin_min = 5,
  pF_lin_max = 7.6,
  pF0_pos = 6.2,
  pF1_neg = 7.6,
  pF_sat_min = -2,
  pF_sat_max = -0.3,
  wet_split_cm = 4.2,
  w_wet = 1,
  w_dry = 1,
  pf_left = -2,
  pf_right = 7.6,
  seed = 123L,
  verbose = TRUE
)
Arguments
- train_df
Data frame for training (output of prepare_swrc_data()).
- x_inputs
Character vector of covariate column names.
- val_df
Optional validation data frame (same structure as train_df). If NULL, early stopping is skipped.
- hidden
Integer vector of length 2: Conv1D filter counts (default c(128L, 64L)).
- dropout
Dropout rate (default 0.10).
- lr
Learning rate for the Adam optimizer (default 1e-3).
- epochs
Maximum number of epochs (default 80).
- batch_size
Mini-batch size (default 256).
- patience
Early-stopping patience in multiples of 5 epochs (default 5).
- K
Number of knot points (default 64L).
- lambdas
Named list of loss weights; use norouzi_lambdas() to generate it (default: norouzi_lambdas("norouzi")).
- S1, S2, S3, S4
Residual set sizes (defaults: 1500, 500, 500, 1500).
- pF_lin_min
Lower pF bound for the S1 linearity constraint (default 5.0).
- pF_lin_max
Upper pF bound for the S1 linearity constraint (default 7.6).
- pF0_pos
pF threshold for S2 (default 6.2).
- pF1_neg
pF threshold for S3 (default 7.6).
- pF_sat_min
Lower pF bound for S4 (default -2.0).
- pF_sat_max
Upper pF bound for S4 (default -0.3).
- wet_split_cm
Matric head (cm) separating the wet and dry ends (default 4.2).
- w_wet
Sample weight for wet observations (default 1.0).
- w_dry
Sample weight for dry observations (default 1.0).
- pf_left
Left pF domain boundary (default -2.0).
- pf_right
Right pF domain boundary (default 7.6).
- seed
Random seed (default 123).
- verbose
Logical; print progress (default TRUE).
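S1 and S4 are documented with explicit pF ranges, so drawing their collocation points can be sketched as below. The uniform sampling is an assumption about the scheme, and `sample_residual_pf()` is a hypothetical helper, not the package's code. S2 and S3 are described only by a single threshold each, so they are left out rather than guessed.

```r
# Hypothetical helper, not part of the package.
# Assumption: residual points are drawn uniformly within each documented pF range.
sample_residual_pf <- function(S1 = 1500L, S4 = 1500L,
                               pF_lin_min = 5, pF_lin_max = 7.6,
                               pF_sat_min = -2, pF_sat_max = -0.3,
                               seed = 123L) {
  set.seed(seed)  # note: sets the global RNG state for reproducibility
  list(
    S1 = runif(S1, min = pF_lin_min, max = pF_lin_max),  # dry-end linearity set
    S4 = runif(S4, min = pF_sat_min, max = pF_sat_max)   # near-saturation set
  )
}

res <- sample_residual_pf()
range(res$S1)  # lies within [5, 7.6]
```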
Value
An S3 object of class swrc_fit: a named list containing:
- theta_model
The fitted Keras model.
- param_model
The theta_s (saturated water content) extractor model.
- x_inputs
Covariate names used.
- scaler
Fitted min-max scaler.
- K
Number of knot points.
- dk
Knot spacing.
- knot_grid
Knot positions in [0, 1].
- pf_left, pf_right
pF domain boundaries.
- theta_factor
Unit multiplier for theta.
- best_epoch
Epoch at which validation loss was minimised.
- lambdas
Loss weights used during training.
- history
Data frame of per-epoch training and validation losses.
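The K, dk, knot_grid and pf_left/pf_right fields fit together roughly as in this sketch. It assumes K evenly spaced knots that include both endpoints (so dk = 1/(K - 1)) and a linear map from [0, 1] onto the pF domain. `make_knot_grid()` and `knot_pf` are hypothetical names; the package may place or space its knots differently.

```r
# Hypothetical helper, not part of the package.
# Assumption: K evenly spaced knots including 0 and 1, mapped linearly to pF.
make_knot_grid <- function(K = 64L, pf_left = -2, pf_right = 7.6) {
  knot_grid <- seq(0, 1, length.out = K)                  # positions in [0, 1]
  dk <- 1 / (K - 1)                                       # spacing between neighbours
  knot_pf <- pf_left + knot_grid * (pf_right - pf_left)   # same knots on the pF scale
  list(knot_grid = knot_grid, dk = dk, knot_pf = knot_pf)
}

kg <- make_knot_grid()
head(kg$knot_pf, 3)
```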
Examples
if (FALSE) { # \dontrun{
  if (reticulate::py_module_available("tensorflow")) {
    df <- prepare_swrc_data(swrc_example, depth_col = "depth")
    fit <- fit_swrc(df,
                    x_inputs = c("clay", "silt", "bd_gcm3", "soc", "Depth_num"),
                    epochs = 2L, verbose = FALSE)
  }
} # }
