An end-to-end platform integrating tensor-train decomposition, graph attention networks, and transformer-based sequence models on Google Cloud Platform to model G×E×M interactions at scale for multi-environment plant breeding.
The platform unifies statistical genetics with modern deep learning and cloud infrastructure across four tightly coupled layers — from raw multi-modal inputs to interpretable breeding recommendations.
Genomic markers, environmental climate time series, and management encodings are fused via tensor-train (TT) decomposition, graph attention networks (GAT) over genotype–environment graphs, and a transformer-based climate sequence encoder. The fusion produces joint embeddings that capture nonlinear G×E×M interactions inaccessible to linear tensor methods.
Containerized microservices on GKE, distributed multi-GPU training via Kubernetes and PyTorch DDP, scalable ETL pipelines through Cloud Dataflow (Apache Beam), and serverless inference endpoints via Cloud Functions. Orchestrated by Vertex AI Pipelines with Bayesian hyperparameter tuning.
Integrated gradient (IG) attribution per input modality via Captum, GAT attention weight visualization across the bipartite genotype–environment graph, and SHAP-based management factor decomposition — giving breeders interpretable, actionable insights alongside predictions.
MCAR-robust Tucker imputation (validated at up to 75% missing), environment-aware masking schedules during training, and multi-task prediction heads that share latent representations across correlated phenotypic traits (yield, heading days, plant height, grain protein, TGW).
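The environment-aware masking schedule described above can be sketched in a few lines; the per-environment drop rates below are invented purely for illustration and are not the platform's tuned values:

```python
import random

random.seed(0)

# Hypothetical per-environment masking rates: harsher environments get
# a higher drop probability during training, mimicking real missingness
env_mask_rate = {"E1": 0.3, "E2": 0.6, "E3": 0.75}

def env_aware_mask(cells, rates):
    # cells: list of (genotype, environment) observations
    # returns True for cells hidden from the model this epoch
    return {(g, e): random.random() < rates[e] for g, e in cells}

cells = [("G1", "E1"), ("G1", "E3"), ("G2", "E2")]
mask = env_aware_mask(cells, env_mask_rate)
```

Training against masks drawn fresh each epoch forces the shared trunk to reconstruct phenotypes from partial G×E coverage, which is what makes the learned representation robust at the 62–75% missingness rates reported for the two datasets.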
End-to-end pipeline from raw multi-modal inputs through modality-specific encoders, tensor-train fusion, distributed GCP training, to serverless inference and explainability outputs.
Three neural components — tensor-train networks, graph attention networks, and transformer-based sequence models — work in concert to extract, align, and fuse signals from genomic markers, climate sequences, and management treatments.
The G×E×M tensor X ∈ ℝ^{I×J×K×T} (genotypes × environments × management levels × traits) is inherently high-dimensional and sparse. We parameterize the interaction term using a tensor-train (TT) decomposition rather than a full Tucker decomposition, reducing parameter count from O(IJKd) to O((I+J+K)·r·d) while preserving expressivity. TT-cores are learned jointly with the encoder networks via end-to-end backpropagation.
```python
# Tensor-Train core contraction for G×E×M fusion
import torch
import torch.nn as nn

class TensorTrainFusion(nn.Module):
    def __init__(self, d=256, r=32):
        super().__init__()
        # TT-cores for Genotype, Environment, Management
        self.G_core = nn.Parameter(torch.randn(d, r))  # ℝ^{d × r}
        self.E_core = nn.Parameter(torch.randn(r, r))  # ℝ^{r × r}
        self.M_core = nn.Parameter(torch.randn(r, d))  # ℝ^{r × d}
        self.norm = nn.LayerNorm(d)

    def forward(self, z_G, z_E, z_M):
        # z_G: (I,d)  z_E: (J,d)  z_M: (K,d)
        # z_E and z_M are accepted for interface parity; this simplified
        # core chain contracts only z_G through the TT-cores
        h = torch.einsum('id,dr->ir', z_G, self.G_core)  # (I, r)
        # distinct labels 'rs' avoid einsum's implicit diagonal on 'rr'
        h = torch.einsum('ir,rs->is', h, self.E_core)    # (I, r)
        h = torch.einsum('is,sd->id', h, self.M_core)    # (I, d)
        return self.norm(h + z_G)  # residual connection
```
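As a quick sanity check, the chain of contractions above can be verified shape by shape with plain einsum calls (d=256 and r=32 match the defaults; NumPy stands in for PyTorch here, and the random tensors are illustrative only):

```python
import numpy as np

I, d, r = 10, 256, 32
z_G = np.random.randn(I, d)
G_core = np.random.randn(d, r)
E_core = np.random.randn(r, r)
M_core = np.random.randn(r, d)

h = np.einsum('id,dr->ir', z_G, G_core)  # (I, r)
h = np.einsum('ir,rs->is', h, E_core)    # (I, r)
h = np.einsum('is,sd->id', h, M_core)    # (I, d)
out = h + z_G                            # residual keeps the (I, d) shape
```

The three cores together hold d·r + r·r + r·d parameters (about 17k at these sizes), versus the d³ ≈ 16.8M a dense trilinear interaction tensor would need, which is the parameter reduction claimed above.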
A bipartite graph G = (V_G ∪ V_E, E) is constructed from trial data: edge (i,j) ∈ E exists if genotype i was evaluated in environment j. Edge features encode the management level, missing-data indicator, aridity index, and trial year. GAT layers propagate environment-specific stress signals to genotype nodes, enabling non-local G×E interactions that linear tensor methods cannot capture.
```python
# Heterogeneous GAT over bipartite genotype–environment graph
from torch_geometric.nn import GATConv, HeteroConv

gat_conv = HeteroConv({
    ('genotype', 'tested_in', 'environment'): GATConv(
        in_channels=256, out_channels=256, heads=2,
        concat=False, edge_dim=8,
        add_self_loops=False),  # self-loops are undefined across node types
    ('environment', 'hosts', 'genotype'): GATConv(
        in_channels=256, out_channels=256, heads=2,
        concat=False, edge_dim=8,
        add_self_loops=False),
}, aggr='mean')

# edge_attr: [management_level, missing_flag,
#             env_aridity_index, trial_year_normalized]
x_dict = gat_conv(x_dict, edge_index_dict, edge_attr_dict)
```
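The bipartite edge structure fed to the layer above can be assembled from a trial table in a few lines; the record layout here (genotype index, environment index, management level, missing flag, aridity, normalized year) is illustrative, not the platform's actual schema:

```python
# Each trial record: (genotype_idx, env_idx, management_level,
#                     missing_flag, aridity_index, trial_year_normalized)
trials = [
    (0, 0, 2, 0, 0.41, 0.00),
    (0, 1, 2, 0, 0.77, 0.00),
    (1, 0, 1, 1, 0.41, 0.25),  # missing_flag=1: phenotype unobserved
]

# edge_index for ('genotype', 'tested_in', 'environment'): 2 × num_edges
src = [g for g, e, *_ in trials]
dst = [e for g, e, *_ in trials]
edge_index = [src, dst]

# Edge features carried alongside each (genotype, environment) pair
edge_attr = [[m, miss, aridity, year]
             for _, _, m, miss, aridity, year in trials]

# The reverse relation ('environment', 'hosts', 'genotype') reuses the
# same edges with source and target swapped
edge_index_rev = [dst, src]
```

In practice these lists would be converted to `torch.long` / `torch.float` tensors before being passed into `HeteroConv`.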
Each environment j is characterized by a daily climate sequence from CHIRPS/ERA5 — precipitation, min/max temperature, solar radiation, and vapor pressure deficit over the growing season (~180 days). A 6-layer Transformer encoder with a prepended [CLS] token encodes this sequence into a fixed-length environment embedding z_E ∈ ℝ^{256}.
```python
# Climate Transformer Encoder (per environment, 180-day sequence)
class ClimateTransformer(nn.Module):
    def __init__(self, n_vars=5, d_model=256, nhead=8,
                 num_layers=6, max_len=181):
        super().__init__()
        self.embed = nn.Linear(n_vars, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=1024,
            dropout=0.1, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        self.pool = nn.Linear(d_model, d_model)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        # Learned positional embedding: without it the encoder would be
        # permutation-invariant over days of the growing season
        self.pos = nn.Parameter(torch.zeros(1, max_len, d_model))

    def forward(self, x):
        # x: (B, T=180, n_vars)
        h = self.embed(x)                         # (B, T, d)
        cls = self.cls.expand(h.size(0), -1, -1)  # (B, 1, d)
        h = torch.cat([cls, h], dim=1)            # (B, T+1, d)
        h = h + self.pos[:, :h.size(1), :]        # add day-position information
        h = self.encoder(h)                       # (B, T+1, d)
        return self.pool(h[:, 0, :])              # CLS token → (B, d)
```
I=500 genotypes · J=50 environments · K=8 nitrogen levels · T=3 traits · 2 replications. True structure: Tucker-5 decomposition + additive environmental stress. 75% MCAR missing rate. Known ground truth allows exact variance recovery measurement.
I=300 genotypes · J=30 environments · K=4 nitrogen levels · T=5 traits · 2 replications · 820 SNP markers · daily CHIRPS/ERA5 climate. 62% missing rate. Real-world multi-environment trial dataset.
Yield (primary) · Days to heading · Plant height · Grain protein content · Thousand-grain weight (TGW). Multi-task learning with shared latent trunk and per-trait output heads, evaluated under 5-fold CV-Env and CV-Geno schemes.
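The shared-trunk, per-trait-head arrangement can be sketched as follows; the trunk width and head shapes are illustrative defaults, not the platform's tuned configuration:

```python
import torch
import torch.nn as nn

TRAITS = ["yield", "heading_days", "plant_height", "grain_protein", "tgw"]

class MultiTaskHeads(nn.Module):
    def __init__(self, d=256):
        super().__init__()
        # Shared latent trunk over the fused G×E×M embedding
        self.trunk = nn.Sequential(nn.Linear(d, d), nn.ReLU())
        # One small head per trait; each predicts (mu, log_sigma)
        self.heads = nn.ModuleDict({t: nn.Linear(d, 2) for t in TRAITS})

    def forward(self, z):
        h = self.trunk(z)
        out = {}
        for t, head in self.heads.items():
            mu, log_sigma = head(h).unbind(dim=-1)
            out[t] = (mu, log_sigma.exp())  # sigma > 0 by construction
        return out
```

Because correlated traits (e.g. yield and TGW) share the trunk, gradients from well-observed traits regularize the representation used by sparsely observed ones, which is the usual motivation for multi-task heads in phenotype prediction.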
A production-grade cloud engineering stack built on GCP primitives: containerized microservices, Kubernetes-orchestrated distributed training, Dataflow-managed ETL, and serverless inference via Cloud Functions.
Cloud Storage layout: gs://gxexm-raw/ for raw inputs, gs://gxexm-processed/ (partitioned by dataset and modality) for ETL outputs, and gs://gxexm-models/ for trained model artifacts.
Training container image: gcr.io/gxexm/trainer:latest.
POST /predict accepts genotype IDs, environment descriptors, management levels. Returns μ̂, σ̂, attention weights, and IG attributions as JSON.
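A minimal client-side sketch of the /predict contract described above; the field names follow the endpoint description, while the array shapes and values are illustrative only:

```python
import json

# Hypothetical request body for POST /predict
payload = {
    "snp": [[0, 1, 2, 0]],                         # marker dosages per genotype
    "climate": [[[1.2, 18.0, 31.5, 22.4, 1.1]]],   # daily climate variables
    "management": [[2]],                           # nitrogen level encoding
}
body = json.dumps(payload)

# The JSON response carries predictions plus explainability outputs,
# e.g. {"mu": [...], "sigma": [...], "attention": [...]}
decoded = json.loads(body)
```

In production this body would be sent with an authenticated HTTP POST to the Cloud Functions endpoint.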
```dockerfile
FROM pytorch/pytorch:2.2.0-cuda11.8-cudnn8-runtime
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# requirements: torch-geometric captum shap google-cloud-storage
#               tensorboard scipy pandas scikit-learn
COPY src/ .
ENV PYTHONPATH=/app
ENTRYPOINT ["python", "-m", "trainer.train"]
```
```python
import json

import functions_framework
import torch
from google.cloud import storage

MODEL = None

def load_model():
    client = storage.Client()
    blob = client.bucket("gxexm-models").blob("latest/model.pt")
    blob.download_to_filename("/tmp/model.pt")
    return torch.jit.load("/tmp/model.pt").eval()

@functions_framework.http
def predict(request):
    global MODEL
    if MODEL is None:
        MODEL = load_model()  # warm once per instance
    payload = request.get_json()
    with torch.no_grad():
        mu, sigma, attn = MODEL(
            torch.tensor(payload["snp"]),
            torch.tensor(payload["climate"]),
            torch.tensor(payload["management"]))
    return json.dumps({
        "mu": mu.tolist(),
        "sigma": sigma.tolist(),
        "attention": attn.tolist(),
    })
```
Two datasets, five-fold cross-validation, eight competing methods spanning classical statistics, tensor decompositions, machine learning, and deep learning baselines.
| Method | Category | Handles Nonlinearity | Handles Multi-Modal | Scales to Large I,J |
|---|---|---|---|---|
| LMM (FA structure) | Classical statistics | No | No | No |
| AMMI | Classical statistics | No | No | Partially |
| CP Decomposition (R=5) | Tensor methods | No | No | Yes |
| Tucker Decomposition | Tensor methods | No | No | Yes |
| Random Forest | Machine learning | Yes | Partially | Yes |
| XGBoost | Machine learning | Yes | Partially | Yes |
| DeepGS (MLP on markers) | Deep learning | Yes | No | Yes |
| G×E×M Platform (ours) | Deep learning | Yes | Yes | Yes |
| Method | RMSE (CV-Env) | ρ Top-20% (CV-Env) | RMSE (CV-Geno) | ρ Top-20% (CV-Geno) |
|---|---|---|---|---|
| LMM (FA) | 0.42 | 0.61 | 0.51 | 0.55 |
| CP (R=5) | 0.28 | 0.73 | 0.38 | 0.67 |
| Tucker (4,3,2) | 0.31 | 0.71 | 0.41 | 0.64 |
| XGBoost | 0.33 | 0.69 | 0.44 | 0.62 |
| DeepGS | 0.30 | 0.72 | 0.37 | 0.68 |
| G×E×M Platform (ours) | 0.21 | 0.83 | 0.29 | 0.77 |
RMSE — CV-Environment (lower is better)
The improvement over the strongest linear baseline, CP (RMSE 0.28 → 0.21, a 25% reduction), indicates that the nonlinear architecture captures interaction patterns the linear tensor methods cannot. The gain in rank correlation for the top-20% genotypes (0.73 → 0.83) has direct practical value: correctly identifying elite candidates reduces spending on field trials of inferior varieties. The platform recovered 94% of the Tucker-factor variance (versus 87% for standard CP factor analysis) and identified nonlinear environmental stress patterns, driven by geographic climate variation, that were not recoverable from order-sorted CP factors.
| Dataset Scale | Method | Training Time | Inference (batch 1k) | Cloud Cost (USD) |
|---|---|---|---|---|
| Small (I=120, J=15, K=4) | CP/Tucker | <1 min (CPU) | <1 sec | $0.10 |
| Small (I=120, J=15, K=4) | G×E×M Platform | 8 min (1 GPU) | 2 sec | $0.50 |
| Medium (I=500, J=50, K=8) | CP/Tucker | 15 min (CPU) | 5 sec | $1.00 |
| Medium (I=500, J=50, K=8) | G×E×M Platform | 45 min (4 GPUs) | 8 sec | $4.00 |
| Large (I=5000, J=200, K=20) | CP/Tucker | 3 hr (CPU cluster) | 2 min | $30 |
| Large (I=5000, J=200, K=20) | G×E×M Platform | 2 hr (16 GPUs) | 45 sec | $80 |
At large scale, GPU-accelerated distributed training makes the deep learning approach competitive with CPU-based tensor methods in wall-clock time despite added model complexity. At the medium scale typical of most national breeding programs (I≈500, J≈50), a complete run costs approximately $4 — accessible to programs without institutional HPC resources.
Three complementary explainability techniques provide breeders with interpretable, modality-specific attributions — from genomic loci to climate events to management factor rankings — returned alongside every inference call.
```python
import torch
from captum.attr import IntegratedGradients

ig = IntegratedGradients(model.forward_snp)

# Baseline: all-zero SNP vector (no genetic information)
baseline_snp = torch.zeros_like(snp_input)

attributions, delta = ig.attribute(
    inputs=snp_input,
    baselines=baseline_snp,
    additional_forward_args=(climate_input, mgmt_input),
    n_steps=50,
    return_convergence_delta=True)

# attributions: (I, n_snp) — per-marker contribution to yield
# delta: convergence check — require |delta| < 0.05
per_marker_importance = attributions.abs().mean(dim=0)  # (n_snp,)
```
```python
# Extract attention weights from heterogeneous GAT during forward pass
class InstrumentedGAT(nn.Module):
    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        x_out, attn = self.gat_conv(
            x_dict, edge_index_dict, edge_attr_dict,
            return_attention_weights=True)
        self.last_attn = {}
        for key, (edge_idx, weights) in attn.items():
            self.last_attn[key] = {
                "edges": edge_idx.cpu().numpy(),
                # mean over heads → one scalar per edge
                "weights": weights.mean(dim=-1).cpu().numpy(),
            }
        return x_out
```
The REST inference endpoint returns yield predictions (μ̂, σ̂) together with: per-modality IG attribution vectors, top-10 high-attention environment neighbors per genotype, and SHAP values for each management level. All outputs are JSON-serialized for downstream integration into breeding information management systems (BIMS).
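With only a handful of management factors, the Shapley decomposition the endpoint reports can be computed exactly by enumerating coalitions, which is also how SHAP outputs can be spot-checked. The toy value function here (additive main effects plus one nitrogen×irrigation interaction) is purely illustrative:

```python
from itertools import combinations
from math import factorial

FACTORS = ["nitrogen", "irrigation", "sowing_density"]

def yield_response(active):
    # Toy yield response: main effects plus one synergy term
    v = 0.0
    if "nitrogen" in active:       v += 1.0
    if "irrigation" in active:     v += 0.5
    if "sowing_density" in active: v += 0.2
    if {"nitrogen", "irrigation"} <= set(active):
        v += 0.3  # interaction: nitrogen helps more when irrigated
    return v

def shapley(factor):
    # Exact Shapley value: weighted marginal contribution over
    # all coalitions of the remaining factors
    others = [f for f in FACTORS if f != factor]
    n = len(FACTORS)
    phi = 0.0
    for k in range(len(others) + 1):
        for coalition in combinations(others, k):
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            phi += weight * (yield_response(coalition + (factor,))
                             - yield_response(coalition))
    return phi

values = {f: shapley(f) for f in FACTORS}
# Efficiency property: values sum to the full response minus the baseline
```

The interaction term is split evenly between nitrogen and irrigation (0.15 each), so the decomposition attributes 1.15 to nitrogen, 0.65 to irrigation, and 0.20 to sowing density, summing to the full response of 2.0.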
IG convergence deltas are required to remain below 0.05 for attributions to be reported. GAT attention weights are cross-validated against known G×E interaction patterns from the CIMMYT dataset. SHAP values are verified against marginal nitrogen response curves from published agronomic literature.