Building Synthetic Data Pipelines for ML Testing
Synthetic data addresses real ML development problems: privacy restrictions on real data, class imbalance, and edge case coverage. It does not replace real data but complements it.
This covers techniques and architectures for generating high-quality synthetic data.
Why Synthetic Data for Machine Learning?
Data Privacy and Compliance
With regulations like GDPR, CCPA, and HIPAA imposing strict requirements on personal data usage, synthetic data offers a compelling alternative:
- No PII Exposure: Generate realistic data without using actual personal information
- Regulatory Compliance: Avoid many data protection requirements when using fully synthetic data
- Risk Reduction: Minimize breach impact when sharing data with external partners
Data Availability and Balance
Real-world datasets often have limitations that synthetic data can address:
- Rare Events: Generate sufficient examples of infrequent but important scenarios
- Class Imbalance: Create balanced datasets for underrepresented classes
- Edge Cases: Systematically produce edge and corner cases for robust testing
Pipeline Development and Testing
Synthetic data enables improved development practices:
- Safe Development: Test pipelines without exposing real customer data
- CI/CD Testing: Automate tests with consistent synthetic datasets
- Schema Validation: Verify pipeline behavior under various data conditions
Techniques for Generating Synthetic Data
1. Statistical Generation
Generate data based on statistical distributions learned from real data:
# Example: Statistical synthetic data generation
import pandas as pd
import numpy as np


class StatisticalSyntheticGenerator:
    """Generate synthetic tabular data from statistics of a real dataset.

    Numerical columns are modeled jointly as a multivariate Gaussian
    (mean vector + covariance matrix) so pairwise correlations are
    preserved; generated values are clipped to each column's observed
    [min, max] range. Categorical columns are sampled independently from
    their empirical frequency distributions.

    NOTE: the original example imported ``GaussianCopula`` from
    ``sklearn.covariance``, which has no such class — this version uses
    NumPy directly so the example actually runs.
    """

    def __init__(self):
        # Per-column summary stats for numerical columns (filled by fit()).
        self.distributions = {}
        # Joint Gaussian parameters for numerical columns.
        self.numerical_cols = []
        self.mean_vector = None
        self.cov_matrix = None
        # Empirical category frequencies for categorical columns.
        self.categorical_cols = []
        self.categorical_distributions = {}

    def fit(self, real_data: pd.DataFrame, categorical_cols: list, numerical_cols: list):
        """Learn marginal stats, joint covariance, and category frequencies."""
        self.numerical_cols = list(numerical_cols)
        for col in numerical_cols:
            self.distributions[col] = {
                'mean': real_data[col].mean(),
                'std': real_data[col].std(),
                'min': real_data[col].min(),
                'max': real_data[col].max(),
            }
        # Median-impute missing values, then learn the joint Gaussian so
        # generated columns keep the correlations observed in real data.
        numerical_data = real_data[numerical_cols].fillna(real_data[numerical_cols].median())
        self.mean_vector = numerical_data.mean().to_numpy()
        self.cov_matrix = np.atleast_2d(np.cov(numerical_data.to_numpy(), rowvar=False))
        self.categorical_cols = list(categorical_cols)
        self.categorical_distributions = {
            col: real_data[col].value_counts(normalize=True).to_dict()
            for col in categorical_cols
        }

    def generate(self, n_samples: int) -> pd.DataFrame:
        """Generate ``n_samples`` synthetic rows as a DataFrame."""
        synthetic_data = {}
        if self.numerical_cols:
            samples = np.random.multivariate_normal(
                self.mean_vector, self.cov_matrix, size=n_samples
            )
            for i, col in enumerate(self.numerical_cols):
                stats = self.distributions[col]
                # Clip to the observed range so outputs stay plausible.
                synthetic_data[col] = np.clip(samples[:, i], stats['min'], stats['max'])
        for col in self.categorical_cols:
            categories = list(self.categorical_distributions[col].keys())
            probabilities = list(self.categorical_distributions[col].values())
            synthetic_data[col] = np.random.choice(categories, n_samples, p=probabilities)
        return pd.DataFrame(synthetic_data)
2. Deep Learning-Based Generation
For complex, high-dimensional data where simple statistical models cannot capture nonlinear structure:
# Example: TVAE (Tabular Variational Autoencoder) for synthetic data
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
class TVAE(nn.Module):
    """Tabular Variational Autoencoder.

    Encodes input rows into the parameters (mu, logvar) of a latent
    Gaussian, draws a latent code via the reparameterization trick, and
    decodes it back to the input space.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Encoder: input -> hidden representation.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Heads producing the latent Gaussian parameters.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: latent code -> reconstruction in input space.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) in training mode; return mu at eval time."""
        if not self.training:
            return mu
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode, sample a latent code, decode; returns (recon, mu, logvar)."""
        hidden = self.encoder(x)
        mu, logvar = self.fc_mu(hidden), self.fc_logvar(hidden)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar
def train_tvae(data: np.ndarray, epochs: int = 100, batch_size: int = 256):
    """Train a TVAE on ``data`` (rows = samples) and return the fitted model.

    The objective is the standard VAE loss: a sum-reduced MSE
    reconstruction term plus the closed-form KL divergence between the
    approximate posterior and a standard normal prior.
    """
    n_features = data.shape[1]
    model = TVAE(n_features, hidden_dim=256, latent_dim=64)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loader = DataLoader(
        TensorDataset(torch.FloatTensor(data)),
        batch_size=batch_size,
        shuffle=True,
    )
    for epoch in range(epochs):
        epoch_loss = 0.0
        for (x,) in loader:
            optimizer.zero_grad()
            x_recon, mu, logvar = model(x)
            # Reconstruction term, summed over all elements in the batch.
            recon_loss = nn.functional.mse_loss(x_recon, x, reduction='sum')
            # KL(q(z|x) || N(0, I)) in closed form.
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = recon_loss + kl_loss
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        if epoch % 10 == 0:
            # Report mean loss per training sample.
            print(f"Epoch {epoch}, Loss: {epoch_loss / len(data):.4f}")
    return model
def generate_synthetic_data(model: TVAE, n_samples: int) -> np.ndarray:
    """Draw ``n_samples`` rows by decoding latent vectors sampled from N(0, I)."""
    model.eval()
    latent_dim = model.fc_mu.out_features
    with torch.no_grad():
        codes = torch.randn(n_samples, latent_dim)
        samples = model.decoder(codes)
    return samples.numpy()
3. Agent-Based Simulation
For generating data with complex interactions:
# Example: Agent-based transaction simulation
import random
from datetime import datetime, timedelta
class TransactionAgent:
    """Simulates one customer's daily banking activity.

    Fix vs. original: every emitted transaction now carries
    ``customer_id``, so the combined multi-customer DataFrame built by
    ``simulate_transactions`` can attribute each row to the agent that
    produced it (previously transactions were untraceable).
    """

    def __init__(self, customer_profile):
        # Static identity and behavioral parameters from the profile dict.
        self.customer_id = customer_profile['customer_id']
        self.balance = customer_profile['initial_balance']
        self.spending_habit = customer_profile['spending_habit']
        self.income_schedule = customer_profile['income_schedule']

    def step(self, date: datetime):
        """Generate one day's transactions and update the running balance."""
        transactions = []
        # Income: deposit the scheduled amount on configured pay days.
        if date.day in self.income_schedule['income_days']:
            income_amount = self.income_schedule['amount']
            transactions.append({
                'customer_id': self.customer_id,
                'date': date,
                'type': 'deposit',
                'amount': income_amount,
                'balance_after': self.balance + income_amount
            })
            self.balance += income_amount
        # Spending: a random number of purchases, each only if affordable.
        daily_transactions = random.randint(0, self.spending_habit['daily_frequency'])
        for _ in range(daily_transactions):
            amount = random.uniform(
                self.spending_habit['min_amount'],
                self.spending_habit['max_amount']
            )
            if self.balance >= amount:
                transactions.append({
                    'customer_id': self.customer_id,
                    'date': date,
                    'type': 'purchase',
                    'amount': amount,
                    'balance_after': self.balance - amount
                })
                self.balance -= amount
        return transactions
def simulate_transactions(n_customers: int, n_days: int):
    """Simulate ``n_days`` of transactions for ``n_customers`` customers.

    Returns a DataFrame with one row per transaction.

    Fix vs. original: the simulation date is reset for every customer so
    all customers cover the SAME calendar window. The original advanced a
    single shared ``date`` variable across customers, meaning customer k
    started on the day customer k-1 finished.
    """
    all_transactions = []
    start_date = datetime(2024, 1, 1)
    for customer_id in range(n_customers):
        # Randomized per-customer behavioral profile.
        profile = {
            'customer_id': f'CUST_{customer_id:05d}',
            'initial_balance': random.uniform(100, 10000),
            'spending_habit': {
                'daily_frequency': random.randint(0, 5),
                'min_amount': random.uniform(5, 20),
                'max_amount': random.uniform(20, 200)
            },
            'income_schedule': {
                'income_days': [1, 15],  # 1st and 15th of month
                'amount': random.uniform(2000, 5000)
            }
        }
        agent = TransactionAgent(profile)
        date = start_date  # reset so every customer shares the same calendar
        for _ in range(n_days):
            all_transactions.extend(agent.step(date))
            date += timedelta(days=1)
    return pd.DataFrame(all_transactions)
Quality Metrics for Synthetic Data
Evaluating synthetic data quality requires comparing multiple dimensions:
# Example: Synthetic data quality evaluation
from scipy.stats import ks_2samp, chi2_contingency
import pandas as pd
class SyntheticDataEvaluator:
def __init__(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame):
self.real = real_data
self.synthetic = synthetic_data
def evaluate_numerical_distributions(self, column: str, threshold: float = 0.05):
"""Use KS test to compare numerical distributions"""
statistic, p_value = ks_2samp(
self.real[column].dropna(),
self.synthetic[column].dropna()
)
return {
'column': column,
'ks_statistic': statistic,
'p_value': p_value,
'passed': p_value > threshold
}
def evaluate_categorical_distributions(self, column: str, threshold: float = 0.05):
"""Use chi-square test to compare categorical distributions"""
real_counts = self.real[column].value_counts(normalize=True)
synthetic_counts = self.synthetic[column].value_counts(normalize=True)
# Align categories
all_categories = set(real_counts.index) | set(synthetic_counts.index)
real_aligned = [real_counts.get(c, 0) for c in all_categories]
synthetic_aligned = [synthetic_counts.get(c, 0) for c in all_categories]
contingency = [real_aligned, synthetic_aligned]
chi2, p_value, dof, expected = chi2_contingency(contingency)
return {
'column': column,
'chi2_statistic': chi2,
'p_value': p_value,
'passed': p_value > threshold
}
def evaluate_privacy(self, synthetic_data: pd.DataFrame, real_data: pd.DataFrame):
"""Check for potential privacy violations"""
metrics = {}
# Check for exact matches (record linkage)
synthetic_hashes = set(synthetic_data.apply(lambda x: hash(tuple(x)), axis=1))
real_hashes = set(real_data.apply(lambda x: hash(tuple(x)), axis=1))
overlap = len(synthetic_hashes & real_hashes)
metrics['exact_match_rate'] = overlap / len(synthetic_data)
metrics['privacy_passed'] = metrics['exact_match_rate'] < 0.001
return metrics
Pipeline Architecture
A production synthetic data pipeline includes:
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ │ │ │ │ │
│ Real Data │────►│ Privacy │────►│ Pattern │
│ Ingestion │ │ Transformation │ │ Learning │
│ │ │ │ │ │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│
▼
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ │ │ │ │ │
│ Synthetic │◄────│ Quality │◄────│ Generation │
│ Data Output │ │ Validation │ │ Engine │
│ │ │ │ │ │
└─────────────────┘ └─────────────────┘ └─────────────────┘
Key components:
- Privacy Transformation: Remove or generalize direct identifiers before pattern learning
- Pattern Learning: Train generative models on transformed data
- Generation Engine: Produce synthetic records at scale
- Quality Validation: Verify synthetic data meets quality thresholds
- Output Management: Store and version synthetic datasets
Decision Rules
Use this checklist for synthetic data decisions:
- If you need to share data externally but privacy regulations prevent it, generate synthetic data
- If your training dataset has severe class imbalance, oversample with synthetic minority examples
- If you lack coverage for edge cases, synthetically generate those scenarios
- If model performance varies across data subgroups, use stratified synthetic data to debug
- If you cannot run CI/CD tests on real data due to compliance, use synthetic data as a stand-in
Synthetic data is only as good as the patterns it learns. Always validate quality before deploying to production.