2D Function Approximation using Torch (via targets in R)

Author

ELC 5365: Deep Learning | Baylor University

1. Introduction and Setup

This document replicates the 05_FuncApprox_2D_PyTorch.ipynb notebook using R, the torch package, and a targets build pipeline.

Instead of defining and training the neural networks sequentially in memory, we rely entirely on the declarative _targets.R pipeline to run the computational graph and cache each target's output.

library(targets)

2. Define the Target Function

We defined the 2D function we want our neural network to learn in our R/functions.R file: \[f(x_1, x_2) = \sin\left(\frac{\pi}{2} x_1\right) \cos\left(\frac{\pi}{4} x_2\right)\] where \(x_1, x_2 \in [-2, 2]\).
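As a quick sanity check (a minimal base-R sketch reusing the definition from R/functions.R), the function can be evaluated directly at a few points with known values:

```r
# Target function, as defined in R/functions.R
target_function_1 <- function(x1, x2) {
  sin(pi * x1 / 2.0) * cos(pi * x2 / 4.0)
}

target_function_1(1, 0)  # sin(pi/2) * cos(0) = 1
target_function_1(0, 2)  # sin(0) * cos(pi/2) = 0
target_function_1(2, 0)  # sin(pi) * cos(0), ~0 up to floating point
```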

3. Visualize the Original Function

# Pull the contour plots built from the baseline evaluation grid
plots <- tar_read(plot_contours)
plots$original

4. Generate Training Data

We generated \(N=5000\) uniformly random points within the domain and saved them to the train_data target.

train_data <- tar_read(train_data)
dim(train_data$X)
[1] 5000    2
dim(train_data$Y)
[1] 5000    1
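The sampling step inside generate_training_data (shown in full in the appendix) amounts to drawing uniform inputs and evaluating the target on them; a condensed base-R sketch:

```r
# Condensed sketch of generate_training_data() for the baseline target
n <- 5000
A <- 2
x1 <- runif(n, min = -A, max = A)
x2 <- runif(n, min = -A, max = A)

X <- cbind(x1, x2)                                          # 5000 x 2 inputs
Y <- matrix(sin(pi * x1 / 2) * cos(pi * x2 / 4), ncol = 1)  # 5000 x 1 targets

dim(X)  # 5000 2
dim(Y)  # 5000 1
```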

5. Define the Neural Network

We built a multi-layer perceptron as a torch nn_module named "FuncApproxNet" (func_approx_net in R/functions.R). By default, it features two hidden layers with 200 neurons each and ReLU activation functions.
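For a sense of scale, the default configuration (2 inputs → 200 → 200 → 1 output) implies the following trainable-parameter count, counted layer by layer over the linear layers:

```r
# Trainable parameters in the default 2 -> 200 -> 200 -> 1 architecture
n_hidden <- 200
input_layer  <- 2 * n_hidden + n_hidden         # weights + biases: 600
hidden_layer <- n_hidden * n_hidden + n_hidden  # weights + biases: 40200
output_layer <- n_hidden * 1 + 1                # weights + biases: 201

input_layer + hidden_layer + output_layer       # 41001 total
```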

6. Train the Model & 7. Plot Training Loss

We trained with mean squared error (MSE) loss and the Adam optimizer. The train_model function in our pipeline automates the training loop, tracking the loss over epochs as well as the overall training time.

tar_read(plot_loss)

8. Evaluate and Visualize the Approximation

We compare the original function with the neural network’s approximation using our plotly-based contour and 3D surface combinations.

# Combined interactive plots (original, predicted, error), titled with the training time
plots <- tar_read(plot_contours)
htmltools::tagList(plots)

9. Experiment: Approximate a Different Function

We also approximated a custom “Function 3” (combining sine and exponential decay terms) using the same baseline architecture:

plots_f3 <- tar_read(plot_contours_f3)
htmltools::tagList(plots_f3)

10. Exercises

1. Change the number of hidden neurons

Try n_hidden = 50 and n_hidden = 400. How does it affect the approximation quality?

# 50 Neurons & 400 Neurons
plots_h50 <- tar_read(plot_contours_h50)
plots_h400 <- tar_read(plot_contours_h400)
htmltools::tagList(plots_h50, plots_h400)

Conclusion: Decreasing the number of neurons limits the model’s capacity to represent complex surfaces, raising the mean error. Increasing it improves the fit, but training takes significantly longer (as measured by the reported Training Time).

2. Add more hidden layers

Modify the network (via the hidden_layers argument to train_model) to have 3 or 4 hidden layers. Does deeper always mean better?

# 3 Layers & 4 Layers
plots_l3 <- tar_read(plot_contours_l3)
plots_l4 <- tar_read(plot_contours_l4)
htmltools::tagList(plots_l3, plots_l4)

Conclusion: Deeper networks don’t necessarily guarantee a better fit for purely smooth, low-frequency 2D functions like this one. They take longer to compute and evaluate and risk getting stuck in local minima or experiencing vanishing gradients if not regularized.

3. Try different activation functions

Replace nn_relu with nn_tanh or nn_leaky_relu (torch for R’s equivalents of PyTorch’s nn.ReLU, nn.Tanh, and nn.LeakyReLU). Compare results.

# Tanh & LeakyReLU
plots_tanh <- tar_read(plot_contours_tanh)
plots_leaky <- tar_read(plot_contours_leaky)
htmltools::tagList(plots_tanh, plots_leaky)

Conclusion: Tanh produces smooth, continuously differentiable surfaces (since the activation itself is smooth), often closely matching sinusoidal targets like ours. LeakyReLU prevents dead neurons and typically converges robustly to a slightly faceted fit, similar to ReLU.

4. Change the learning rate

Try lr = 0.01 and lr = 0.0001. How does it affect training speed and final accuracy?

# lr = 0.01 & 0.0001
plots_lr01 <- tar_read(plot_contours_lr01)
plots_lr0001 <- tar_read(plot_contours_lr0001)
htmltools::tagList(plots_lr01, plots_lr0001)

Conclusion: A large learning rate (0.01) can overshoot the optimal parameters, often plateauing at a suboptimal fit (high error). A small learning rate (0.0001) tends to give stable, smooth descent, but typically requires far more epochs to converge. Because we kept epochs fixed at 2000, 0.0001 usually results in an incomplete fit.
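The tradeoff can be illustrated with a toy 1-D gradient descent on \(f(w) = w^2\) (a deliberately simplified stand-in; the pipeline itself uses Adam, which adapts step sizes):

```r
# Plain gradient descent on f(w) = w^2 for a fixed number of steps
descend <- function(lr, steps = 50, w = 1) {
  for (i in seq_len(steps)) {
    w <- w - lr * 2 * w  # gradient of w^2 is 2w
  }
  w
}

descend(lr = 0.1)     # converges close to the minimum at 0
descend(lr = 0.0001)  # barely moves in 50 steps (underfit at a fixed budget)
descend(lr = 1.1)     # diverges: each step overshoots and grows |w|
```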

5. Define your own function

Create a new target function and train the network to approximate it.

For this target, we modeled a complex bivariate skew-normal mixture density. Note that the domain for this function spans \(x_1, x_2 \in [-4, 8]\) rather than \([-2, 2]\).

The probability density function for a bivariate skew-normal mixture with \(K\) components is defined as:

\[ f(x) = \sum_{k=1}^K w_k \cdot 2 \phi_2(x; \xi_k, \Omega_k) \Phi_1(\alpha_k^T \Omega_k^{-1/2} (x - \xi_k)) \]

where:

- \(x = (x_1, x_2)^T\)
- \(w_k\) are the mixing weights (such that \(\sum w_k = 1\))
- \(\phi_2(x; \xi_k, \Omega_k)\) is the bivariate normal probability density function with location vector \(\xi_k\) and covariance matrix \(\Omega_k\)
- \(\Phi_1(\cdot)\) is the cumulative distribution function of the standard univariate normal distribution
- \(\alpha_k\) is the shape/skewness vector for the \(k\)-th component
- \(\Omega_k^{-1/2}\) is the inverse of the diagonal matrix of standard deviations (so \(\Omega_k = \Omega_k^{1/2} \bar{\Omega}_k \Omega_k^{1/2}\), where \(\bar{\Omega}_k\) is the correlation matrix), which simplifies the skew-scaling argument

We set \(K = 2\), with specific parameters:

- \(w_1 = 0.4,\ \xi_1 = (0, 0),\ \Omega_1 = \begin{pmatrix} 2.0 & 0.8 \\ 0.8 & 1.5 \end{pmatrix},\ \alpha_1 = (3, -2)^T\)
- \(w_2 = 0.6,\ \xi_2 = (4, 3),\ \Omega_2 = \begin{pmatrix} 1.5 & -0.5 \\ -0.5 & 2.0 \end{pmatrix},\ \alpha_2 = (-1, 4)^T\)
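To spot-check the formula numerically, here is a self-contained base-R evaluation of a single component (dmvnorm2 and skew_component are illustrative stand-ins for the mvtnorm-based helper in R/functions.R). At \(x = \xi_1\) the skew factor is \(\Phi_1(0) = 0.5\), so the component reduces to \(w_1 \phi_2(\xi_1; \xi_1, \Omega_1)\):

```r
# Bivariate normal density (base-R stand-in for mvtnorm::dmvnorm)
dmvnorm2 <- function(x, mu, sigma) {
  d <- x - mu
  exp(-0.5 * as.numeric(t(d) %*% solve(sigma) %*% d)) /
    (2 * pi * sqrt(det(sigma)))
}

# Weighted skew-normal component density at a single point
skew_component <- function(x, w, xi, omega, alpha) {
  sd_inv <- solve(diag(sqrt(diag(omega))))  # Omega^{-1/2}: inverse diagonal SDs
  skew_arg <- as.numeric(t(x - xi) %*% sd_inv %*% alpha)
  w * 2 * dmvnorm2(x, xi, omega) * pnorm(skew_arg)
}

omega1 <- matrix(c(2, 0.8, 0.8, 1.5), nrow = 2)
skew_component(c(0, 0), w = 0.4, xi = c(0, 0), omega = omega1, alpha = c(3, -2))
# equals 0.4 / (2 * pi * sqrt(det(omega1))) since pnorm(0) = 0.5
```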

The targets pipeline adapted to this easily: we passed min_val = -4 and max_val = 8 when generating both the training data and the evaluation grid.

plots_f4 <- tar_read(plot_contours_f4)
htmltools::tagList(plots_f4)

11. Appendix

This section contains the overall DAG (Directed Acyclic Graph) of the targets pipeline, followed by the raw code files used to run it.

Full Targets Pipeline Graph

tar_visnetwork()

Script: _targets.R

library("targets")

tar_option_set(
  packages = c("torch", "dplyr", "ggplot2", "tidyr")
)

tar_source("R/functions.R")

list(
  # ----------------------------------------------------------------------------
  # BASELINE TARGETS (Original)
  # ----------------------------------------------------------------------------
  tar_target(
    train_data,
    generate_training_data(func = target_function_1, n = 5000, A = 2)
  ),
  tar_target(
    eval_grid,
    generate_evaluation_grid(func = target_function_1, A = 2, grid_size = 100)
  ),
  tar_target(trained_model, train_model(train_data)),
  tar_target(evaluation_results, evaluate_model(trained_model, eval_grid)),
  tar_target(plot_loss, plot_training_loss(trained_model)),
  tar_target(
    plot_contours,
    plot_results_contour(evaluation_results, trained_model = trained_model)
  ),

  # ----------------------------------------------------------------------------
  # EXERCISE 1: Number of Hidden Neurons
  # ----------------------------------------------------------------------------
  # n_hidden = 50
  tar_target(trained_model_h50, train_model(train_data, n_hidden = 50)),
  tar_target(eval_res_h50, evaluate_model(trained_model_h50, eval_grid)),
  tar_target(
    plot_contours_h50,
    plot_results_contour(eval_res_h50, trained_model = trained_model_h50)
  ),

  # n_hidden = 400
  tar_target(trained_model_h400, train_model(train_data, n_hidden = 400)),
  tar_target(eval_res_h400, evaluate_model(trained_model_h400, eval_grid)),
  tar_target(
    plot_contours_h400,
    plot_results_contour(eval_res_h400, trained_model = trained_model_h400)
  ),

  # ----------------------------------------------------------------------------
  # EXERCISE 2: Number of Hidden Layers
  # ----------------------------------------------------------------------------
  # layers = 3
  tar_target(trained_model_l3, train_model(train_data, hidden_layers = 3)),
  tar_target(eval_res_l3, evaluate_model(trained_model_l3, eval_grid)),
  tar_target(
    plot_contours_l3,
    plot_results_contour(eval_res_l3, trained_model = trained_model_l3)
  ),

  # layers = 4
  tar_target(trained_model_l4, train_model(train_data, hidden_layers = 4)),
  tar_target(eval_res_l4, evaluate_model(trained_model_l4, eval_grid)),
  tar_target(
    plot_contours_l4,
    plot_results_contour(eval_res_l4, trained_model = trained_model_l4)
  ),

  # ----------------------------------------------------------------------------
  # EXERCISE 3: Activation Functions
  # ----------------------------------------------------------------------------
  # Tanh
  tar_target(trained_model_tanh, train_model(train_data, activation = nn_tanh)),
  tar_target(eval_res_tanh, evaluate_model(trained_model_tanh, eval_grid)),
  tar_target(
    plot_contours_tanh,
    plot_results_contour(eval_res_tanh, trained_model = trained_model_tanh)
  ),

  # LeakyReLU
  tar_target(
    trained_model_leaky,
    train_model(train_data, activation = nn_leaky_relu)
  ),
  tar_target(eval_res_leaky, evaluate_model(trained_model_leaky, eval_grid)),
  tar_target(
    plot_contours_leaky,
    plot_results_contour(eval_res_leaky, trained_model = trained_model_leaky)
  ),

  # ----------------------------------------------------------------------------
  # EXERCISE 4: Learning Rates
  # ----------------------------------------------------------------------------
  # lr = 0.01
  tar_target(trained_model_lr01, train_model(train_data, lr = 0.01)),
  tar_target(eval_res_lr01, evaluate_model(trained_model_lr01, eval_grid)),
  tar_target(
    plot_contours_lr01,
    plot_results_contour(eval_res_lr01, trained_model = trained_model_lr01)
  ),

  # lr = 0.0001
  tar_target(trained_model_lr0001, train_model(train_data, lr = 0.0001)),
  tar_target(eval_res_lr0001, evaluate_model(trained_model_lr0001, eval_grid)),
  tar_target(
    plot_contours_lr0001,
    plot_results_contour(eval_res_lr0001, trained_model = trained_model_lr0001)
  ),

  # ----------------------------------------------------------------------------
  # SECTION 9 EXPERIMENT: Custom Function (target_function_3)
  # ----------------------------------------------------------------------------
  tar_target(
    train_data_f3,
    generate_training_data(func = target_function_3, n = 5000, A = 2)
  ),
  tar_target(
    eval_grid_f3,
    generate_evaluation_grid(func = target_function_3, A = 2, grid_size = 100)
  ),
  tar_target(trained_model_f3, train_model(train_data_f3)),
  tar_target(eval_res_f3, evaluate_model(trained_model_f3, eval_grid_f3)),
  tar_target(
    plot_contours_f3,
    plot_results_contour(eval_res_f3, trained_model = trained_model_f3)
  ),

  # ----------------------------------------------------------------------------
  # EXERCISE 5: Custom Function 4 (Skew-Normal Mixture)
  # ----------------------------------------------------------------------------
  tar_target(
    train_data_f4,
    generate_training_data(func = target_function_4_wrapper, n = 5000, min_val = -4, max_val = 8)
  ),
  tar_target(
    eval_grid_f4,
    generate_evaluation_grid(func = target_function_4_wrapper, grid_size = 150, min_val = -4, max_val = 8)
  ),
  tar_target(trained_model_f4, train_model(train_data_f4)),
  tar_target(eval_res_f4, evaluate_model(trained_model_f4, eval_grid_f4)),
  tar_target(
    plot_contours_f4,
    plot_results_contour(eval_res_f4, trained_model = trained_model_f4)
  )
)

Script: R/functions.R

suppressPackageStartupMessages({
  library("torch")
  library("dplyr")
  library("ggplot2")
  library("tidyr")
  library("plotly")
  library("purrr")
  library("mvtnorm")
})

# 1. Target functions

#' Target function 1
#'
#' @param x1 A numeric vector or scalar.
#' @param x2 A numeric vector or scalar.
#' @return A numeric vector or scalar with the computed values.
target_function_1 <- function(x1, x2) {
  sin(pi * x1 / 2.0) * cos(pi * x2 / 4.0)
}

#' Target function 2
#'
#' A more complex function from the end of the Notebook.
#'
#' @param x1 A numeric vector or scalar.
#' @param x2 A numeric vector or scalar.
#' @return A numeric vector or scalar with the computed values.
target_function_2 <- function(x1, x2) {
  # A more complex function from the end of the Notebook
  exp(sin(sqrt(x1^2 + x2^2)))
}

#' Target function 3
#'
#' A custom function for Exercise 3.
#'
#' @param x1 A numeric vector or scalar.
#' @param x2 A numeric vector or scalar.
#' @return A numeric vector or scalar with the computed values.
target_function_3 <- function(x1, x2) {
  # A custom function for Exercise 3
  sin(x1) * exp(-0.1 * (x1^2 + x2^2))
}

#' Wrapper for target function 4
#'
#' @param x1 A numeric vector or scalar.
#' @param x2 A numeric vector or scalar.
#' @return A numeric vector or scalar with the computed values.
target_function_4_wrapper <- function(x1, x2) {
  weights <- c(0.4, 0.6)
  xi_list <- list(c(0, 0), c(4, 3))
  omega_list <- list(
    matrix(c(2, 0.8, 0.8, 1.5), nrow = 2),
    matrix(c(1.5, -0.5, -0.5, 2), nrow = 2)
  )
  alpha_list <- list(c(3, -2), c(-1, 4))

  target_function_4(
    x = cbind(x1, x2),
    weights = weights,
    xi_list = xi_list,
    omega_list = omega_list,
    alpha_list = alpha_list
  )
}

#' Helper function to compute the weighted density of a single bivariate skew-normal component
#'
#' @param x A numeric matrix of dimensions n x 2.
#' @param weight Mixing weight for the component.
#' @param xi Location vector of length 2.
#' @param omega Covariance matrix of dimensions 2 x 2.
#' @param alpha Skewness vector of length 2.
#' @return A numeric vector of length n representing the component's weighted densities.
calc_component_density <- function(x, weight, xi, omega, alpha) {
  x_mat <- as.matrix(x)

  # Construct the diagonal standard deviation matrix and its inverse
  omega_diag <- diag(sqrt(diag(omega)))
  omega_inv <- solve(omega_diag)

  # Center the data matrix around xi
  x_centered <- sweep(x_mat, 2, xi, "-")

  # Calculate the bivariate normal PDF
  phi_part <- mvtnorm::dmvnorm(x_mat, mean = xi, sigma = omega)

  # Calculate the skew-scaling CDF factor
  skew_arg <- as.numeric(x_centered %*% omega_inv %*% alpha)
  Phi_part <- pnorm(skew_arg)

  # Return the weighted skew normal density
  weight * 2 * phi_part * Phi_part
}

#' Probability density function for a bivariate skew-normal mixture
#'
#' @param x A numeric matrix or data frame of dimensions n x 2.
#' @param weights A numeric vector of K mixing weights.
#' @param xi_list A list of K location vectors.
#' @param omega_list A list of K covariance matrices.
#' @param alpha_list A list of K skewness vectors.
#' @return A numeric vector of length n.
target_function_4 <- function(x, weights, xi_list, omega_list, alpha_list) {
  # Normalize mixing proportions
  w <- weights / sum(weights)

  # Iterate over parameters to compute and sum the mixture components
  list(weight = w, xi = xi_list, omega = omega_list, alpha = alpha_list) |>
    purrr::pmap(\(weight, xi, omega, alpha) {
      calc_component_density(x, weight, xi, omega, alpha)
    }) |>
    purrr::reduce(`+`)
}

# 2. Generate Training Data

#' Generate training data
#'
#' @param func A function to generate the target values (`y`).
#' @param n An integer specifying the number of samples. Default is 5000.
#' @param A A numeric specifying the scaling factor for the domain. Default is 2.
#' @param min_val A numeric specifying the minimum value of the domain. Default is `-A`.
#' @param max_val A numeric specifying the maximum value of the domain. Default is `A`.
#' @return A list containing `X` (the input features) and `Y` (the target values).
generate_training_data <- function(func, n = 5000, A = 2, min_val = -A, max_val = A) {
  # Seeding is omitted here to mirror the original notebook, which relies on the ambient random state
  x1_train <- runif(n, min = min_val, max = max_val)
  x2_train <- runif(n, min = min_val, max = max_val)
  y_train <- func(x1_train, x2_train)

  X_train <- cbind(x1_train, x2_train)
  Y_train <- matrix(y_train, ncol = 1)

  list(X = X_train, Y = Y_train)
}

# 3. Model Architecture

#' Neural network module for function approximation
#'
#' @param n_hidden Integer specifying the number of hidden units per layer.
#' @param hidden_layers Integer specifying the number of hidden layers.
#' @param activation A torch activation function.
#' @return A `torch::nn_module` object.
func_approx_net <- nn_module(
  "FuncApproxNet",
  initialize = function(
    n_hidden = 200,
    hidden_layers = 2,
    activation = nn_relu
  ) {
    modules <- list()

    # Input layer
    modules[[1]] <- nn_linear(2, n_hidden)
    modules[[2]] <- activation()

    # Hidden layers (if any)
    if (hidden_layers > 1) {
      for (i in seq_len(hidden_layers - 1)) {
        modules[[length(modules) + 1]] <- nn_linear(n_hidden, n_hidden)
        modules[[length(modules) + 1]] <- activation()
      }
    }

    # Output layer
    modules[[length(modules) + 1]] <- nn_linear(n_hidden, 1)

    # Create sequential container
    self$net <- do.call(nn_sequential, modules)
  },
  forward = function(x) {
    self$net(x)
  }
)

# 4. Training Loop

#' Train the neural network model
#'
#' @param data A list containing training data `X` and `Y`.
#' @param num_epochs Integer specifying the number of training epochs. Default is 2000.
#' @param batch_size Integer specifying the batch size. Default is 256.
#' @param lr Numeric specifying the learning rate. Default is 0.001.
#' @param n_hidden Integer specifying the number of hidden units per layer. Default is 200.
#' @param hidden_layers Integer specifying the number of hidden layers. Default is 2.
#' @param activation A torch activation function. Default is `nn_relu`.
#' @return A list with the trained model state, loss history, training duration, and architecture parameters.
train_model <- function(
  data,
  num_epochs = 2000,
  batch_size = 256,
  lr = 0.001,
  n_hidden = 200,
  hidden_layers = 2,
  activation = nn_relu
) {
  # Determine device: preferring CUDA then MPS then CPU
  device <- if (cuda_is_available()) {
    "cuda"
  } else if (backends_mps_is_available()) {
    "mps"
  } else {
    "cpu"
  }

  message(sprintf("Training on device: %s", device))

  model <- func_approx_net(
    n_hidden = n_hidden,
    hidden_layers = hidden_layers,
    activation = activation
  )
  model$to(device = device)

  X_train <- torch_tensor(data$X, dtype = torch_float32())$to(device = device)
  Y_train <- torch_tensor(data$Y, dtype = torch_float32())$to(device = device)
  N <- X_train$shape[1]

  criterion <- nn_mse_loss()
  optimizer <- optim_adam(model$parameters, lr = lr)

  loss_history <- numeric(num_epochs)

  start_time <- Sys.time()

  for (epoch in 1:num_epochs) {
    perm <- torch_randperm(N, dtype = torch_long())$to(device = device)
    epoch_loss <- 0.0
    num_batches <- 0

    # R uses 1-based indexing
    for (i in seq(1, N, by = batch_size)) {
      end_idx <- min(i + batch_size - 1, N)

      # Extract batch indices: torch_randperm() in the R torch package
      # returns a 1-based permutation of 1..N, usable directly as indices
      idx <- as.integer(perm[i:end_idx])

      X_batch <- X_train[idx, ]
      Y_batch <- Y_train[idx, ]

      # Forward pass
      y_pred <- model(X_batch)
      loss <- criterion(y_pred, Y_batch)

      # Backward and optimize
      optimizer$zero_grad()
      loss$backward()
      optimizer$step()

      epoch_loss <- epoch_loss + as.numeric(loss$item())
      num_batches <- num_batches + 1
    }

    avg_loss <- epoch_loss / num_batches
    loss_history[epoch] <- avg_loss

    if (epoch %% 200 == 0) {
      message(sprintf("Epoch [%d/%d], Loss: %.6f", epoch, num_epochs, avg_loss))
    }
  }

  end_time <- Sys.time()
  train_duration <- difftime(end_time, start_time, units = "secs")
  message(sprintf(
    "Training completed in %.2f seconds",
    as.numeric(train_duration)
  ))

  # Return state_dict (weights) to avoid serializing external pointers in targets
  state <- as.list(model$state_dict())
  # Convert torch tensors to pure R arrays in the state dict
  state_r <- lapply(state, function(x) as.array(x$cpu()))

  list(
    state = state_r,
    loss_history = loss_history,
    duration_secs = as.numeric(train_duration),
    n_hidden = n_hidden,
    hidden_layers = hidden_layers,
    activation = activation
  )
}

# 5. Evaluation and Visualization

#' Generate an evaluation grid
#'
#' @param func A function to compute the original target values.
#' @param A A numeric specifying the domain limit. Default is 2.
#' @param grid_size Integer specifying the number of points along each dimension. Default is 100.
#' @param min_val A numeric specifying the minimum limit of the grid. Default is `-A`.
#' @param max_val A numeric specifying the maximum limit of the grid. Default is `A`.
#' @return A data frame containing the grid points `x1`, `x2`, and target values `y_original`.
generate_evaluation_grid <- function(func, A = 2, grid_size = 100, min_val = -A, max_val = A) {
  x1_eval <- seq(min_val, max_val, length.out = grid_size)
  x2_eval <- seq(min_val, max_val, length.out = grid_size)

  grid <- expand.grid(x1 = x1_eval, x2 = x2_eval)
  grid$y_original <- func(grid$x1, grid$x2)

  grid
}

#' Evaluate the trained model
#'
#' @param trained A list returned by `train_model` containing the model state.
#' @param grid A data frame generated by `generate_evaluation_grid`.
#' @return A list containing the updated grid with predictions and errors, max error, and mean error.
evaluate_model <- function(trained, grid) {
  # Retrieve network parameters from training result or use defaults for backwards compatibility
  n_hidden <- if (!is.null(trained$n_hidden)) trained$n_hidden else 200
  hidden_layers <- if (!is.null(trained$hidden_layers)) {
    trained$hidden_layers
  } else {
    2
  }
  activation <- if (!is.null(trained$activation)) {
    trained$activation
  } else {
    nn_relu
  }

  # Re-initialize model and load weights
  model <- func_approx_net(
    n_hidden = n_hidden,
    hidden_layers = hidden_layers,
    activation = activation
  )

  # Convert pure R arrays back to torch tensors and load into model
  state_torch <- lapply(trained$state, function(x) torch_tensor(x))
  model$load_state_dict(state_torch)

  model$eval()

  X_test <- torch_tensor(
    as.matrix(grid[, c("x1", "x2")]),
    dtype = torch_float32()
  )

  with_no_grad({
    y_pred <- model(X_test)
  })

  grid$y_pred <- as.numeric(y_pred)
  grid$error <- abs(grid$y_original - grid$y_pred)

  max_error <- max(grid$error)
  mean_error <- mean(grid$error)

  message(sprintf("Max absolute error: %.6f", max_error))
  message(sprintf("Mean absolute error: %.6f", mean_error))

  list(grid = grid, max_error = max_error, mean_error = mean_error)
}

#' Plot the training loss
#'
#' @param trained A list returned by `train_model`.
#' @return A `ggplot` object showing the loss curve.
plot_training_loss <- function(trained) {
  df <- data.frame(
    epoch = seq_along(trained$loss_history),
    loss = trained$loss_history
  )
  p <- ggplot(df, aes(x = epoch, y = loss)) +
    geom_line() +
    scale_y_log10() +
    labs(x = "Epoch", y = "MSE Loss", title = "Training Loss Over Epochs") +
    theme_minimal() +
    theme(panel.grid.minor = element_blank())

  p
}

#' Plot the results as contours and surfaces
#'
#' @param eval_result A list returned by `evaluate_model`.
#' @param trained_model An optional list returned by `train_model`.
#' @return A list of `plotly` objects for the original function, predicted function, and error.
plot_results_contour <- function(eval_result, trained_model = NULL) {
  grid <- eval_result$grid
  x1_vals <- sort(unique(grid$x1))
  x2_vals <- sort(unique(grid$x2))

  z_orig <- matrix(
    grid$y_original,
    nrow = length(x1_vals),
    ncol = length(x2_vals)
  )
  z_pred <- matrix(grid$y_pred, nrow = length(x1_vals), ncol = length(x2_vals))
  z_error <- matrix(grid$error, nrow = length(x1_vals), ncol = length(x2_vals))

  create_plot_pair <- function(z_mat, title_main, z_title) {
    p_contour <- plotly::plot_ly(
      x = ~x1_vals,
      y = ~x2_vals,
      z = ~ t(z_mat),
      type = "contour",
      colorscale = "Viridis",
      contours = list(coloring = "heatmap"), # Continuous mapping
      colorbar = list(title = z_title, x = 0.45, len = 0.8)
    )

    p_surface <- plotly::plot_ly(
      x = ~x1_vals,
      y = ~x2_vals,
      z = ~ t(z_mat),
      type = "surface",
      colorscale = "Viridis",
      showscale = FALSE
    )

    subtitle <- if (
      !is.null(trained_model) && !is.null(trained_model$duration_secs)
    ) {
      sprintf(
        "<br><sup>Training Time: %.2f seconds</sup>",
        trained_model$duration_secs
      )
    } else {
      ""
    }

    full_title <- paste0(title_main, subtitle)

    plotly::subplot(p_contour, p_surface) |>
      plotly::layout(
        title = list(text = full_title),
        scene = list(
          xaxis = list(title = "x1"),
          yaxis = list(title = "x2"),
          zaxis = list(title = z_title),
          domain = list(x = c(0.55, 1), y = c(0, 1))
        ),
        xaxis = list(domain = c(0, 0.45), title = "x1"),
        yaxis = list(title = "x2"),
        annotations = list(
          list(
            x = 0.225,
            y = 1.05,
            text = "Contour Plot",
            showarrow = FALSE,
            xref = "paper",
            yref = "paper",
            xanchor = "center"
          ),
          list(
            x = 0.775,
            y = 1.05,
            text = "Surface Plot",
            showarrow = FALSE,
            xref = "paper",
            yref = "paper",
            xanchor = "center"
          )
        )
      ) |>
      plotly::config(mathjax = "cdn")
  }

  p1 <- create_plot_pair(
    z_orig,
    "Original Function: $f(x_1, x_2) = \\sin(\\pi x_1/2)\\cos(\\pi x_2/4)$",
    "f(x_1, x_2)"
  )
  p2 <- create_plot_pair(z_pred, "Neural Network Approximation", "pred")
  p3 <- create_plot_pair(z_error, "Absolute Error", "error")

  list(original = p1, predicted = p2, error = p3)
}