Improving Out-of-Distribution Detection in Deep Neural Networks

In machine learning, out-of-distribution (OOD) data refers to input data that differs significantly from the data used to train a model. Such inputs can cause deep neural network classifiers to produce incorrect predictions with high confidence, which can have severe consequences in real-world applications. OOD detection is therefore crucial to ensuring the reliability and safety of deep neural network classifiers.

In this article, we will explore some challenges associated with OOD data and discuss techniques that can be used to enhance OOD detection in machine learning models. We will also show how the DataHeroes Python library, leveraging Coresets, can be an efficient and effective tool for improving OOD detection algorithms.

Why OOD Detection Matters in Machine Learning

Consider a deep learning model trained on suburban housing data to predict housing prices. Urban high-rise apartments constructed in the same region after training never appeared in its training data. Deep learning models learn patterns from training data and can struggle when test-time inputs differ significantly from those patterns, so the model may have difficulty producing accurate price estimates for the urban apartments. This unfamiliarity with urban property dynamics can lead to incorrect predictions, potentially affecting real estate investment decisions and market analysis.

Enhancing Model Reliability with OOD Detection

In deep learning, OOD detection has emerged as crucial for ensuring the reliability and robustness of models. When deploying machine learning models in real-world scenarios, it is essential to identify data instances that lie outside the training data’s distribution, for reasons ranging from model accuracy to adaptability to evolving data landscapes and ethical AI practices. Detecting OOD data is vital for:

  • Mitigating Misclassifications: Identifying OOD data prevents models from misclassifying inputs into incorrect in-distribution classes with high confidence.
  • Ensuring Model Safety: Accurate OOD detection is essential in safety-critical domains such as autonomous vehicles and medical diagnosis, where mistaken predictions can have dire consequences.

Understanding the Importance of Distribution in Model Training

The distribution of the data used during model training plays a vital role in OOD detection: a model can only flag departures from a distribution it has characterized well.

  • Distribution Mismatch: When the distribution of real-world data deviates significantly from the training data, OOD inputs may be misclassified, leading to unreliable predictions. Detecting such mismatches is essential for model reliability and performance. For instance, a model trained on images of city streets may struggle to classify images of rural landscapes accurately because of differences in scenery and objects present (a minimal illustration of detecting such a mismatch follows this list).
  • Evolution of Data: Real-world data changes constantly and usually shows high variability. New classes may emerge over time, and the model must be able to handle data from unseen classes that were not present during its initial training. For example, new types of spam emails with novel content and masking techniques may arise over time in spam email detection. OOD detection is necessary to find and handle these new spam variations, enabling the model to adapt to dynamic data landscapes. This adaptability is vital for the model’s continued relevance and utility in evolving environments.
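To make the idea of a distribution mismatch concrete, here is a minimal, self-contained sketch, separate from the DataHeroes workflow below, that flags a shift in a single feature using a two-sample Kolmogorov-Smirnov test from SciPy. The synthetic arrays and the significance level are illustrative assumptions:

# Detecting a shift in one feature with a two-sample KS test (illustrative)
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(42)
train_feature = rng.normal(loc=0.0, scale=1.0, size=1000)     # training distribution
incoming_feature = rng.normal(loc=2.0, scale=1.0, size=200)   # shifted live data

statistic, p_value = ks_2samp(train_feature, incoming_feature)
if p_value < 0.01:
    print(f"Likely distribution shift (KS statistic={statistic:.3f}, p={p_value:.2e})")

A low p-value says the incoming values are unlikely to come from the training distribution; in practice you would run such a check per feature, or on model confidence scores, rather than on raw synthetic data.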

OOD Detection using DataHeroes

1. Loading Libraries

Let’s begin by importing the required libraries. The DataHeroes library provides an efficient implementation of Coresets, which are representative weighted subsets of the original data, reducing computational overhead and enabling faster OOD detection algorithms.

# Importing the required libraries
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from dataheroes import CoresetTreeServiceLG

2. Loading the Datasets

Let’s load two datasets to train and evaluate our model. The first is the in-distribution data used for model training; the second is out-of-distribution, i.e., data unlike what the model was trained on, and will be used for evaluation.

# Load in-distribution and out-of-distribution datasets as pandas DataFrames
in_distribution_df = pd.read_csv('in_distribution_dataset.csv')
out_distribution_df = pd.read_csv('out_distribution_dataset.csv')
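The file names above are placeholders. If you want to run the walkthrough end to end without your own data, a small synthetic stand-in works; this sketch assumes the two feature columns X1 and X2 and the target column Label used throughout the rest of the article:

# Optional: generate illustrative in-distribution and OOD data instead of CSVs
from sklearn.datasets import make_classification

X_in, y_in = make_classification(n_samples=5000, n_features=2, n_informative=2,
                                 n_redundant=0, random_state=42)
in_distribution_df = pd.DataFrame(X_in, columns=['X1', 'X2'])
in_distribution_df['Label'] = y_in

# Shift the features to simulate out-of-distribution inputs
out_distribution_df = pd.DataFrame(X_in[:500] + 4.0, columns=['X1', 'X2'])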

3. Creating a Coreset for OOD Detection

Since the training data can be exceptionally large, we’ll create a Coreset to efficiently represent the training data and build a Coreset tree, which will be used for OOD detection. The Coreset allows the algorithm to efficiently perform OOD detection on new data without performing computations on the entire training set.

# Split the in-distribution data into training and validation sets
train_df, validation_df = train_test_split(in_distribution_df, test_size=0.2, random_state=42)

# Keep a copy of the training split on disk for reproducibility
data_dir = Path("data")
data_dir.mkdir(parents=True, exist_ok=True)
in_distribution_train_file_path = data_dir / "in_distribution_train.csv"
train_df.to_csv(in_distribution_train_file_path, index=False)

# Tell the CoresetTreeServiceLG how the data is structured
data_params = {'target': {'name': 'Label'}}

# Initialize the CoresetTreeServiceLG object and build the Coreset tree
service_obj = CoresetTreeServiceLG(data_params=data_params,
                                   optimized_for='training',
                                   n_classes=2,  # Binary classification (0 or 1)
                                   n_instances=len(train_df))
service_obj.build_from_df(train_df)
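A note on the arguments: optimized_for='training' builds a Coreset tree tuned for model training (as opposed to data cleaning), while n_classes and n_instances describe the dataset so the tree can be sized appropriately. Here we build directly from the DataFrame with build_from_df; the CSV written above is simply a persisted copy of the training split.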

4. OOD Detection with Coreset and Logistic Regression

Next, we extract the Coreset samples and their associated weights from the tree and train a logistic regression model on them. This weighted model stands in for a model trained on the full dataset, and its predicted class probabilities on new data will serve as our OOD signal.

# Get the Coreset data and weights
coreset = service_obj.get_coreset()
indices, X_coreset, y_coreset = coreset['data']
weights = coreset['w']
 
# Train a logistic regression model on the Coreset
coreset_model = LogisticRegression().fit(X_coreset, y_coreset, sample_weight=weights)
 
# Extract features from the out-of-distribution dataset
X_out_distribution = out_distribution_df[['X1', 'X2']].values
 
# Predict probabilities for the out-of-distribution samples
ood_scores = coreset_model.predict_proba(X_out_distribution)
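Each row of ood_scores holds the predicted probabilities of class 0 and class 1. A confident model pushes one of the two toward 1, so a row whose maximum probability is low signals an input that looks unlike the training data; that maximum is the confidence score we threshold in the next step.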

5. Setting the Threshold for OOD Detection

A suitable threshold for OOD detection is determined using the in-distribution validation set. Since every validation sample comes from the training distribution by construction, a simple and common approach is to set the threshold at a low percentile of the model’s maximum predicted class probability, so that most of the in-distribution data (here, 95%) is retained; anything less confident than that is flagged as out-of-distribution.

# Use the in-distribution validation set to set a threshold for OOD detection
X_validation = validation_df[['X1', 'X2']].values
validation_scores = coreset_model.predict_proba(X_validation)

# The model's confidence is its maximum predicted class probability
validation_confidence = np.max(validation_scores, axis=1)

# Choose the threshold so that 95% of in-distribution samples are retained
best_threshold = np.percentile(validation_confidence, 5)

print(f"Selected threshold: {best_threshold}")

# Flag OOD samples whose confidence falls below the threshold
is_ood = np.max(ood_scores, axis=1) < best_threshold

6. Displaying OOD Scores and OOD Instances

Finally, we print the class-probability scores and the binary classification results (is_ood) for the OOD data. Low maximum probabilities indicate inputs the model is unfamiliar with, which is exactly the signal the threshold acts on.

# Display OOD scores and Instances of OOD for the out-of-distribution dataset
print("OOD Scores:")
print(ood_scores)
 
print("Instances of OOD:")
print(is_ood)
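Raw arrays are hard to eyeball. Since every row here comes from the OOD dataset, summarizing the flagged fraction (an addition to the original walkthrough) gives a rough sense of the detector’s recall:

# Sanity check: fraction of the OOD dataset actually flagged as OOD
print(f"Fraction flagged as OOD: {is_ood.mean():.1%}")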

By following these steps, we can efficiently perform OOD detection using Coresets and a logistic regression model, gaining valuable insight into how the model handles unseen data.

Conclusion

OOD detection is crucial to ensuring the reliability and safety of deep neural network classifiers. Tools such as the DataHeroes Python library, with its efficient Coreset implementation, make OOD detection practical even on large datasets. By carefully considering data distribution and employing techniques like those shown here, practitioners can build more robust models capable of meeting the demands of real-world applications.
