RandomForestRegressor
The `RandomForestRegressor` (imported from `sklearn.ensemble`), a member of the bagging family of ensemble methods, is a scikit-learn algorithm that builds a collection (forest) of decision trees for regression tasks. It averages the predictions of these trees to produce a final prediction, which improves predictive accuracy and reduces overfitting.
How it works:
- Creates multiple decision trees: each tree is trained on a bootstrap sample of the dataset (drawn with replacement).
- Averages results: predictions from all trees are averaged to make the final prediction, as the sketch below demonstrates.
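To make the averaging concrete, here is a minimal sketch (the `make_regression` data and the small forest size are illustrative assumptions) showing that the forest's prediction is simply the mean of its trees' predictions:

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Synthetic data, purely for illustration
X, y = make_regression(n_samples=200, n_features=4, random_state=0)
reg = RandomForestRegressor(n_estimators=10, random_state=0).fit(X, y)

# Each fitted tree lives in reg.estimators_; the forest's prediction
# is the mean of the per-tree predictions.
tree_preds = np.stack([tree.predict(X[:1]) for tree in reg.estimators_])
print(tree_preds.mean(axis=0))  # manual average across the 10 trees
print(reg.predict(X[:1]))       # matches the forest's own prediction
```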
Syntax Example
Basic Usage
```python
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import make_regression
# Generate sample data
X, y = make_regression(n_samples=1000, n_features=4, random_state=0, shuffle=False)
# Train RandomForestRegressor
reg = RandomForestRegressor(max_depth=2, random_state=0)
# Fitting the model
reg.fit(X, y)
# Make a prediction
print(reg.predict([[0, 0, 0, 0]]))
```
Common Parameters
```python
reg = RandomForestRegressor(
    n_estimators=200,       # 200 trees in the forest
    max_depth=10,           # limit each tree's depth to 10
    min_samples_split=5,    # a node needs at least 5 samples to be split
    max_features='sqrt',    # consider sqrt(n_features) at each split
    random_state=42         # reproducibility
)
```
Key Parameters and Their Functions
| Parameter | Default | Description | Options/Details | Best Practice |
| --- | --- | --- | --- | --- |
| `n_estimators` | `100` | Number of trees. | Any positive integer. | Start with `100`; increase (`300`-`500`) if needed for better results. |
| `criterion` | `"squared_error"` | Metric to measure the quality of a split. | `"squared_error"` (MSE), `"absolute_error"` (MAE), `"poisson"`, `"friedman_mse"`. | `"squared_error"` for most cases; `"absolute_error"` for robustness to outliers. |
| `max_depth` | `None` | Maximum depth of each tree. | Positive integer or `None`. | Limit depth (`5`-`20`) to prevent overfitting. |
| `min_samples_split` | `2` | Minimum samples required to split an internal node. | Integer or float (fraction). | Increase (e.g., `5`-`10`) for noisy datasets. |
| `min_samples_leaf` | `1` | Minimum samples required at a leaf node. | Integer or float. | Use higher values (`2`-`5`) to smooth predictions. |
| `max_features` | `1.0` | Number of features considered at each split. | `"sqrt"`, `"log2"`, `None`, int, float. | `"sqrt"` or `"log2"` recommended for high-dimensional data. |
| `bootstrap` | `True` | Whether samples are drawn with replacement. | `True` or `False`. | `True` in most cases. |
| `oob_score` | `False` | Whether to use out-of-bag samples to estimate the generalization score. | `True` or `False`. | Set `True` when no separate validation set is available. |
| `random_state` | `None` | Controls randomness for reproducibility. | Integer or `None`. | Use a fixed integer (e.g., `42`). |
| `n_jobs` | `None` | Number of parallel jobs. | `None`, `-1`, or an integer. | `-1` to use all cores. |
| `warm_start` | `False` | Reuse the solution of the previous call and add more estimators. | `True` or `False`. | `True` for incremental training. |
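The best-practice column suggests ranges rather than single values; one common way to pick within them is a grid search. A hedged sketch, where the grid values and the synthetic data are illustrative assumptions, not recommendations:

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

X, y = make_regression(n_samples=500, n_features=10, noise=10.0, random_state=42)

# Small, illustrative grid over the parameters discussed above
param_grid = {
    "n_estimators": [100, 300],
    "max_depth": [5, 10, None],
    "min_samples_leaf": [1, 3],
}
search = GridSearchCV(RandomForestRegressor(random_state=42), param_grid, cv=3)
search.fit(X, y)
print(search.best_params_)
```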
Features of RandomForestRegressor
| Feature | Description | Key Parameter/Attribute |
| --- | --- | --- |
| Parallel computation | Speed up training using multiple CPU cores. | `n_jobs` |
| Feature importances | Evaluate the importance of each feature. | `feature_importances_` |
| Out-of-bag estimates | Validate the model without separate validation data. | `oob_score` |
| Tree growth control | Prevent overfitting by limiting tree growth. | `max_depth`, `min_samples_leaf`, `max_leaf_nodes` |
| Warm start | Expand the model incrementally without retraining. | `warm_start` |
Expanded Features of RandomForestRegressor with Examples
1. Parallel Computation (`n_jobs`)
Use multiple CPU cores to speed up model training, especially useful for large datasets.
- Usage:
  - `n_jobs=-1`: use all available CPU cores.
  - `n_jobs=1`: single core.
  - `n_jobs=2`: two cores.
- Example:
```python
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=10000, n_features=20, random_state=42)
reg = RandomForestRegressor(n_estimators=100, n_jobs=-1, random_state=42)
reg.fit(X, y)
print("Model trained using all available cores.")
2. Feature Importances
Random forests rank features by how much each one contributes to the splits (impurity reduction) across all trees, and therefore to the predictions.
- Example:
```python
# Access feature importances ('reg' here is the fitted 4-feature model
# from the Basic Usage example above)
importances = reg.feature_importances_
# Print feature importances
for idx, importance in enumerate(importances):
    print(f"Feature {idx}: Importance = {importance:.4f}")
```
- Output Example:
```text
Feature 0: Importance = 0.3945
Feature 1: Importance = 0.2661
Feature 2: Importance = 0.0701
Feature 3: Importance = 0.2693
```
Use this for feature selection or dimensionality reduction, as in the sketch below.
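One way to act on the importances is scikit-learn's `SelectFromModel`; a minimal sketch, assuming a fresh forest fitted on synthetic 20-feature data:

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_selection import SelectFromModel

X, y = make_regression(n_samples=500, n_features=20, n_informative=5, random_state=42)
reg = RandomForestRegressor(n_estimators=100, random_state=42).fit(X, y)

# Keep only features whose importance exceeds the mean importance
selector = SelectFromModel(reg, threshold="mean", prefit=True)
X_selected = selector.transform(X)
print(X.shape, "->", X_selected.shape)  # only the most important features remain
```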
3. Out-of-Bag (OOB) Estimates
Useful for estimating performance without a separate validation set; for regression, `oob_score_` is the R² computed on each tree's out-of-bag samples.
- Usage:
```python
# oob_score requires bootstrap=True (the default)
reg = RandomForestRegressor(n_estimators=100, oob_score=True, bootstrap=True, random_state=42)
reg.fit(X, y)
# Print OOB score
print(f"OOB Score: {reg.oob_score_:.4f}")
```
- Example Output:
```text
OOB Score: 0.8973
```
4. Tree Growth Control
Control complexity to prevent overfitting.
- Key parameters: `max_depth`, `min_samples_leaf`, `max_leaf_nodes`
- Example:
```python
reg = RandomForestRegressor(
    n_estimators=50,
    max_depth=5,           # cap each tree's depth
    min_samples_leaf=10,   # every leaf needs at least 10 samples
    max_leaf_nodes=20,     # at most 20 leaves per tree
    random_state=42
)
reg.fit(X, y)
```
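To see why these limits matter, a sketch comparing an unconstrained forest with a constrained one on held-out data (exact scores depend on the dataset and its noise level):

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=1000, n_features=10, noise=20.0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Compare an unconstrained forest with a depth/leaf-limited one
for name, params in [("unconstrained", {}),
                     ("constrained", {"max_depth": 5, "min_samples_leaf": 10})]:
    reg = RandomForestRegressor(n_estimators=100, random_state=42, **params)
    reg.fit(X_train, y_train)
    print(f"{name}: train R2 = {reg.score(X_train, y_train):.3f}, "
          f"test R2 = {reg.score(X_test, y_test):.3f}")
```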
5. Warm Start
Add more trees incrementally without retraining from scratch.
- Example:
```python
# Initial training with 50 trees
reg = RandomForestRegressor(n_estimators=50, warm_start=True, random_state=42)
reg.fit(X, y)
# Add 50 more trees: raise n_estimators, then call fit again
reg.set_params(n_estimators=100)
reg.fit(X, y)  # only the 50 new trees are trained
print(f"Total trees after warm start: {len(reg.estimators_)}")
```
Practical Notes
- Overfitting prevention: use `max_depth`, `min_samples_leaf`, and `max_features`.
- High-dimensional data: random forests handle many features well, but training time increases.
- Outliers: if the data has many outliers, consider `criterion="absolute_error"` for better robustness.
- Missing values: `RandomForestRegressor` does not impute missing values for you; preprocess the data first (e.g., impute), as in the sketch below.
When to Use
RandomForestRegressor is ideal for:
- Medium to large datasets.
- Datasets with complex feature interactions.
- Cases where model interpretability is less important than prediction accuracy.
- Data with non-linear relationships.