GSoC'24 Week 1 & 2 Progress: Building a Cross-Validation Framework in Octave
In the past two weeks, I have been working on enhancing the statistics package for GNU Octave by implementing a cross-validation framework for classification models. The focus has been on adding a 'crossval' method to the 'ClassificationKNN' class and creating a 'ClassificationPartitionedModel' class with a 'kfoldPredict' method. Together, these provide a solid foundation for cross-validation in Octave. This post walks through the progress, challenges, and solutions.
Cross-validation is a method used to evaluate and improve the performance of machine learning models. It is essential in classification models, where the goal is to categorize data into predefined classes. Cross-validation helps assess how well a model will generalize to an independent dataset, which is critical for ensuring its robustness and reliability.
The process involves partitioning the dataset into k folds, training the model on k-1 of them, and validating it on the remaining fold. This cycle is repeated until each fold has served as the validation set exactly once, so the model is evaluated on every subset of the data.
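For instance, with 5 folds each observation is held out for validation exactly once. Here is a minimal, illustrative sketch of the partitioning step using the statistics package's cvpartition class (the variable names and printed output are mine, not from the implementation):
pkg load statistics
n = 150;                            ## number of observations
cv = cvpartition (n, "KFold", 5);   ## split row indices into 5 folds
for k = 1:5
  trainIdx = training (cv, k);      ## logical index: rows in the k-1 training folds
  testIdx  = test (cv, k);          ## logical index: rows in the validation fold
  printf ("Fold %d: %d training, %d validation\n", ...
          k, sum (trainIdx), sum (testIdx));
endfor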
To start, I focused on setting up the 'ClassificationPartitionedModel' class, which handles cross-validation for classification models. Its constructor initializes the properties and trains a model on each set of k-1 folds while reserving the remaining fold for validation. This involved handling different types of classification models, with the initial implementation focusing on ClassificationKNN. Here's a breakdown of the constructor, followed by a short sketch:
- Initialization of Basic Properties: The constructor initializes fundamental properties like X, Y, ClassNames, Cost, CrossValidatedModel, KFold, ModelParameters, NumObservations, Partition, PredictorNames, Prior, ResponseName, and Trained.
- Training Initialization: The constructor initializes the training process by setting up the partition object for cross-validation and preparing a cell array to store the trained models.
- Model Training Loop: The constructor iterates over each fold, creating training and validation sets, training the model on the training set, and storing the trained model.
- Object Construction: Finally, the constructor creates the ClassificationPartitionedModel object with the initialized properties and the trained models.
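To make the flow concrete, here is a simplified sketch of the constructor's core training loop. It is illustrative only: 'Mdl' stands for the ClassificationKNN object being cross-validated, 'partition' for its cvpartition object, and the actual class initializes many more properties and edge cases.
## Illustrative sketch, not the actual class source: train one fold
## model per training split and collect them in 'Trained'
Trained = cell (partition.NumTestSets, 1);
for k = 1:partition.NumTestSets
  idx = training (partition, k);    ## rows in the k-1 training folds
  ## Refit the classifier on the training subset, reusing the
  ## parameters of the original model
  Trained{k} = fitcknn (Mdl.X(idx,:), Mdl.Y(idx), ...
                        "NumNeighbors", Mdl.NumNeighbors);
endfor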
Next, I implemented the 'crossval' method in the 'ClassificationKNN' class. This method sets up the partitioning for k-fold cross-validation and creates an instance of 'ClassificationPartitionedModel'. Here's a breakdown of the 'crossval' method, with a sketch after the list:
- Input Parsing: The method parses input arguments to retrieve cross-validation parameters like the number of folds (KFold).
- Partition Object Creation: It creates a partition object to define the training and validation splits for k-fold cross-validation.
- Model Parameters and Initialization: The method initializes the model parameters structure and a cell array to store the models trained on each fold.
- Training Loop: It iterates over each fold, setting up training and validation indices, and trains the model on the training set for each fold. The trained models are stored in the cell array.
- Cross-Validation Model Creation: Finally, the method creates and returns an instance of the 'ClassificationPartitionedModel' class, initialized with the partitioned data and the trained models.
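A hedged sketch of how such a method might look (illustrative only; the actual method performs fuller input validation, and here the per-fold training is delegated to the constructor sketched above rather than done inline):
## Illustrative sketch of the crossval method (simplified)
function CVMdl = crossval (this, varargin)
  kfold = 10;                            ## default number of folds
  for i = 1:2:numel (varargin)           ## parse Name-Value pairs
    if (strcmpi (varargin{i}, "KFold"))
      kfold = varargin{i+1};
    endif
  endfor
  ## Define the training/validation splits for k-fold cross-validation
  partition = cvpartition (numel (this.Y), "KFold", kfold);
  ## Hand off to the partitioned-model constructor, which trains the
  ## fold models as sketched earlier
  CVMdl = ClassificationPartitionedModel (this, partition);
endfunction
I then implemented the 'kfoldPredict' method for 'ClassificationPartitionedModel', which predicts labels, scores, and costs for each observation using the fold model that never saw it during training. Here's a breakdown, with a sketch after the list: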
- Prediction Initialization: The method initializes the output vector 'label' according to the type of Y.
- Score and Cost Matrices: Initializes matrices for storing classification scores and costs.
- Model Prediction Loop: Iterates over each fold, predicting the class labels, scores, and costs for the held-out validation set.
- Result Aggregation: Aggregates the predictions from all folds to form the final result.
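Putting those steps together, a simplified sketch of the prediction loop (illustrative only; it assumes cell-string class labels and glosses over the type handling):
## Illustrative sketch of kfoldPredict's core loop; 'this' is a
## ClassificationPartitionedModel with cell-string labels in Y
n = this.NumObservations;
nclasses = numel (this.ClassNames);
label = cell (n, 1);                ## predicted labels
score = nan (n, nclasses);          ## classification scores
cost  = nan (n, nclasses);          ## expected misclassification costs
for k = 1:this.KFold
  idx = test (this.Partition, k);   ## held-out rows of fold k
  ## Each fold model predicts only the observations it never saw
  [label(idx), score(idx,:), cost(idx,:)] = ...
    predict (this.Trained{k}, this.X(idx,:));
endfor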
Beyond the implementation, I added tests covering the new functionality:
- Ensuring the correct initialization of the 'ClassificationPartitionedModel' object.
- Validating the 'kfoldPredict' and 'crossval' outputs against expected results.
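For illustration, a BIST-style test in the spirit of those checks (not the actual test suite; the exact assertions in the package differ):
%!test
%! load fisheriris
%! mdl = fitcknn (meas, species, "NumNeighbors", 5);
%! cv = crossval (mdl, "KFold", 4);
%! assert (class (cv), "ClassificationPartitionedModel")
%! assert (cv.KFold, 4)
%! [label, score, cost] = kfoldPredict (cv);
%! assert (numel (label), rows (meas))
%! assert (size (score), [rows(meas), numel(unique(species))])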
Here is a simple demonstration using the fisheriris dataset:
## Load the statistics package (provides fitcknn and fisheriris)
pkg load statistics
load fisheriris
x = meas;
y = species;
## Create a KNN classifier model
obj = fitcknn (x, y, "NumNeighbors", 5, "Standardize", 1);
## Perform cross-validation
cvMdl = crossval (obj, "KFold", 5);
## Predict the class labels for the observations not used for training
[label, score, cost] = kfoldPredict (cvMdl);
You can see the implementations here.
Significant progress has been made in implementing a cross-validation framework for classification models in Octave. The 'ClassificationPartitionedModel' class and 'kfoldPredict' method, along with the 'crossval' method in 'ClassificationKNN', provide a strong foundation for further enhancements and integrations.
Suggestions and feedback are welcome 😃.