diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 8caf2aa8..aab6ae6a 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -6,9 +6,9 @@ Please check the following items before submitting your PR: -- [ ] I have created a new folder and YAML metadata file `models/<arch_name>/<model_variant>.yml` for my submission. `arch_name` is the name of the architecture and `model_variant.yml` includes things like author details, training set names and/or important hyperparameters. +- [ ] I have created a new folder and YAML metadata file `models/<arch_name>/<model_variant>.yml` for my submission. `arch_name` is the name of the architecture and `model_variant.yml` includes things like author details, training set names and important hyperparameters. - [ ] I have added the my new model as a new attribute on the [`Model.` enum](https://github.com/janosh/matbench-discovery/blob/57d0d0c8a14cd317/matbench_discovery/enums.py#L274) in `enums.py`. -- [ ] I have uploaded the energy/force/stress model prediction file for the WBM test set to Figshare or another cloud storage service (`--preds.csv.gz`). +- [ ] I have uploaded the energy/force/stress model prediction file for the WBM test set to Figshare or another cloud storage service (`--preds.csv.gz`). - [ ] I have uploaded the model-relaxed structures file to Figshare or another cloud storage service (`-wbm-IS2RE-FIRE.json.gz`). - [ ] I have uploaded the phonon predictions to Figshare or another cloud storage service (`-kappa-103-FIRE-.gz`). - [ ] I have included the urls to the Figshare files in the YAML metadata file (`models/<arch_name>/<model_variant>.yml`). If not using Figshare I have included the urls to the cloud storage service in the description of the PR. 
diff --git a/models/sevennet/sevennet-mf-ompa.yml b/models/sevennet/sevennet-mf-ompa.yml index f1700643..b201dc82 100644 --- a/models/sevennet/sevennet-mf-ompa.yml +++ b/models/sevennet/sevennet-mf-ompa.yml @@ -101,6 +101,24 @@ metrics: pred_file: models/sevennet/sevennet-mf-ompa/2025-03-11-wbm-geo-opt-FIRE.json.gz pred_file_url: https://figshare.com/files/52983491 struct_col: sevennet_structure + symprec=1e-2: + rmsd: 0.0115 # Å + n_sym_ops_mae: 1.7053 # unitless + symmetry_decrease: 0.0467 # fraction + symmetry_match: 0.8181 # fraction + symmetry_increase: 0.128 # fraction + n_structures: 256963 # count + analysis_file: models/sevennet/sevennet-mf-ompa/2025-03-11-wbm-geo-opt-FIRE-symprec=1e-2-moyo=0.4.2.csv.gz + analysis_file_url: https://figshare.com/files/53029115 + symprec=1e-5: + rmsd: 0.0115 # Å + n_sym_ops_mae: 2.0326 # unitless + symmetry_decrease: 0.0439 # fraction + symmetry_match: 0.7057 # fraction + symmetry_increase: 0.2453 # fraction + n_structures: 256963 # count + analysis_file: models/sevennet/sevennet-mf-ompa/2025-03-11-wbm-geo-opt-FIRE-symprec=1e-5-moyo=0.4.2.csv.gz + analysis_file_url: https://figshare.com/files/53029142 discovery: pred_file: models/sevennet/sevennet-mf-ompa/2025-03-11-wbm-IS2RE.csv.gz pred_file_url: https://figshare.com/files/52983488 diff --git a/site/src/lib/GeoOptMetricsTable.svelte b/site/src/lib/GeoOptMetricsTable.svelte index c17de925..bf885f54 100644 --- a/site/src/lib/GeoOptMetricsTable.svelte +++ b/site/src/lib/GeoOptMetricsTable.svelte @@ -157,7 +157,7 @@ {#snippet cell({ col, val })} {#if col.label === `Links` && val} {@const links = val} - {#each links.files as { url: href, title, icon } (href)} + {#each links.files as { url: href, title, icon } (title + href)} {#if href} {@html icon} diff --git a/site/src/lib/HeatmapTable.svelte b/site/src/lib/HeatmapTable.svelte index ddbd3121..8a8e1e06 100644 --- a/site/src/lib/HeatmapTable.svelte +++ b/site/src/lib/HeatmapTable.svelte @@ -8,11 +8,7 @@ import { 
titles_as_tooltips } from 'svelte-zoo/actions' import { flip } from 'svelte/animate' import { writable } from 'svelte/store' - import type { HeatmapColumn } from './types' - - type CellVal = string | number | undefined | null - type RowData = Record - type TableData = RowData[] + import type { CellVal, HeatmapColumn, RowData, TableData } from './types' interface Props { data: TableData @@ -20,6 +16,10 @@ sort_hint?: string style?: string | null cell?: Snippet<[{ row: RowData; col: HeatmapColumn; val: CellVal }]> + controls?: Snippet + initial_sort_column?: string + initial_sort_direction?: `asc` | `desc` + fixed_header?: boolean } let { @@ -28,9 +28,19 @@ sort_hint = `Click on column headers to sort table rows`, style = null, cell, + controls, + initial_sort_column, + initial_sort_direction, + fixed_header = false, }: Props = $props() - const sort_state = writable({ column: ``, ascending: true }) + // Add container reference for binding + let container: HTMLDivElement + + const sort_state = writable({ + column: initial_sort_column || ``, + ascending: initial_sort_direction !== `desc`, + }) let clean_data = $state(data) $effect(() => { @@ -70,13 +80,27 @@ }) } - function calc_color(value: number | string | undefined, col: HeatmapColumn) { - if (col.color_scale === null || typeof value !== `number`) + function calc_color(value: number | string | undefined | null, col: HeatmapColumn) { + // Skip color calculation for null values or if color_scale is null + if ( + value === null || + value === undefined || + col.color_scale === null || + typeof value !== `number` + ) { return { bg: null, text: null } + } const col_id = get_col_id(col) - const values = clean_data.map((row) => row[col_id]) - const range = [min(values) ?? 0, max(values) ?? 
1] + const numericValues = clean_data + .map((row) => row[col_id]) + .filter((val): val is number => typeof val === `number`) // Type guard to ensure we only get numbers + + if (numericValues.length === 0) { + return { bg: null, text: null } + } + + const range = [min(numericValues) ?? 0, max(numericValues) ?? 1] if (col.better === `lower`) { range.reverse() } @@ -116,14 +140,28 @@ } -
- + +
+ {#if Object.keys($sort_state).length && sort_hint} +
{sort_hint}
+ {/if} + + + {#if controls} +
+ {@render controls()} +
+ {/if} +
+ +
+
{#if visible_columns.some((col) => col.group)} - {#each visible_columns as {label, group, tooltip} (label + group)} + {#each visible_columns as { label, group, tooltip } (label + group)} {#if !group} {:else} @@ -169,13 +207,13 @@ style:background-color={color.bg} style:color={color.text} style={col.style} - title={[undefined, null].includes(val) ? `not available` : null} + title={typeof val === `undefined` || val === null ? `not available` : null} > {#if cell} {@render cell({ row, col, val })} {:else if typeof val === `number` && col.format} {pretty_num(val, col.format)} - {:else if [undefined, null].includes(val)} + {:else if val === undefined || val === null} n/a {:else} {@html val} @@ -191,15 +229,11 @@ diff --git a/site/src/lib/MetricsTable.svelte b/site/src/lib/MetricsTable.svelte index ddaeba5b..80d4c6d9 100644 --- a/site/src/lib/MetricsTable.svelte +++ b/site/src/lib/MetricsTable.svelte @@ -1,50 +1,114 @@ - + + {#snippet controls()} + { + handle_filter_change(show_energy, show_noncomp) + }} + on_col_change={(column, visible) => { + // Update visible_cols state instead of columns + visible_cols[column] = visible + }} + /> + {/snippet} + {#snippet cell({ col, val })} - {#if col.label === `Links` && val} - {@const links = val} - {#each [links.paper, links.repo, links.pr_url] as {title, url, icon} (title + url)} - {#if url} - - {icon} + {#if col.label === `Links` && val && typeof val === `object` && `paper` in val} + {@const links = val as LinkData} + {#each [links.paper, links.repo, links.pr_url] as link (link?.title + link?.url)} + {#if link?.url} + + {link.icon} {/if} {/each} @@ -202,7 +419,7 @@ {/if} {:else if typeof val === `number` && col.format} {pretty_num(val, col.format)} - {:else if [undefined, null].includes(val)} + {:else if val === undefined || val === null} n/a {:else} {@html val} @@ -230,7 +447,7 @@

Download prediction files for {active_model_name}

    - {#each active_files as {name, url} (name + url)} + {#each active_files as { name, url } (name + url)}
  1. {name} @@ -242,6 +459,13 @@ diff --git a/site/src/lib/TableColumnToggleMenu.svelte b/site/src/lib/TableColumnToggleMenu.svelte index ecd6fea3..1949a3f9 100644 --- a/site/src/lib/TableColumnToggleMenu.svelte +++ b/site/src/lib/TableColumnToggleMenu.svelte @@ -59,6 +59,7 @@ min-width: 150px; display: grid; grid-template-columns: repeat(auto-fill, minmax(120px, 1fr)); + z-index: 1; /* needed to ensure column toggle menu is above HeatmapTable header row */ } .column-menu label { display: inline-block; diff --git a/site/src/lib/TableControls.svelte b/site/src/lib/TableControls.svelte new file mode 100644 index 00000000..9dcdbcb6 --- /dev/null +++ b/site/src/lib/TableControls.svelte @@ -0,0 +1,114 @@ + + +
    + MP v2022.10.28 release)
    + We still show these models behind a toggle as we expect them
    to nonetheless + provide helpful signals for developing future models. + + {/snippet} + + + + + + +
    + + diff --git a/site/src/lib/index.ts b/site/src/lib/index.ts index 953771be..29e15436 100644 --- a/site/src/lib/index.ts +++ b/site/src/lib/index.ts @@ -18,8 +18,10 @@ export { default as ModelCard } from './ModelCard.svelte' export { default as Nav } from './Nav.svelte' export { default as PtableHeatmap } from './PtableHeatmap.svelte' export { default as PtableInset } from './PtableInset.svelte' +export { default as RadarChart } from './RadarChart.svelte' export { default as References } from './References.svelte' export { default as TableColumnToggleMenu } from './TableColumnToggleMenu.svelte' +export { default as TableControls } from './TableControls.svelte' export * from './types' export { data_files } diff --git a/site/src/lib/metrics.ts b/site/src/lib/metrics.ts index f4d9c1fc..f280d286 100644 --- a/site/src/lib/metrics.ts +++ b/site/src/lib/metrics.ts @@ -1,4 +1,4 @@ -import type { DiscoverySet, HeatmapColumn } from './types' +import type { CombinedMetricConfig, DiscoverySet, HeatmapColumn } from './types' export const METADATA_COLS: HeatmapColumn[] = [ { label: `Model`, sticky: true }, @@ -45,7 +45,37 @@ export const PHONON_METRICS: HeatmapColumn[] = [ }, ] -export const ALL_METRICS: HeatmapColumn[] = [...DISCOVERY_METRICS, ...PHONON_METRICS] +// Define geometry optimization metrics +export const GEO_OPT_METRICS: HeatmapColumn[] = [ + { + label: `RMSD`, + tooltip: `Root mean squared displacement between predicted and reference structures after relaxation`, + style: `border-left: 1px solid black;`, + }, + { + label: `Energy Diff`, + tooltip: `Mean absolute energy difference between predicted and reference structures`, + }, + { + label: `Force RMSE`, + tooltip: `Root mean squared error of forces in predicted structures relative to reference`, + }, + { + label: `Stress RMSE`, + tooltip: `Root mean squared error of stress in predicted structures relative to reference`, + }, + { + label: `Max Force`, + tooltip: `Maximum force component in predicted 
structures after relaxation`, + }, +] + +// Update ALL_METRICS to include GEO_OPT_METRICS +export const ALL_METRICS: HeatmapColumn[] = [ + ...DISCOVERY_METRICS, + ...PHONON_METRICS, + ...GEO_OPT_METRICS.slice(0, 1), // Only include RMSD by default, others can be toggled +] export const DISCOVERY_SET_LABELS: Record< DiscoverySet, @@ -65,3 +95,146 @@ export const DISCOVERY_SET_LABELS: Record< tooltip: `Metrics computed on the 10k structures predicted to be most stable (different for each model)`, }, } + +export const [F1_DEFAULT_WEIGHT, RMSD_DEFAULT_WEIGHT, KAPPA_DEFAULT_WEIGHT] = [ + 0.5, 0.1, 0.4, +] + +export const DEFAULT_COMBINED_METRIC_CONFIG: CombinedMetricConfig = { + name: `CPS`, + description: `Combined Performance Score weighting discovery, structure optimization, and phonon performance`, + weights: [ + { + metric: `F1`, + label: `F1`, + description: `F1 score for stable/unstable material classification (discovery task)`, + value: F1_DEFAULT_WEIGHT, + }, + { + metric: `kappa_SRME`, + label: `ÎșSRME`, + description: `Symmetric relative mean error for thermal conductivity prediction (lower is better)`, + value: KAPPA_DEFAULT_WEIGHT, + }, + { + metric: `RMSD`, + label: `RMSD`, + description: `Root mean square displacement for crystal structure optimization`, + value: RMSD_DEFAULT_WEIGHT, + }, + ], +} + +// F1 score is between 0-1 where higher is better (no normalization needed) +function normalize_f1(value: number | undefined): number { + if (value === undefined || isNaN(value)) return 0 + return value // Already in [0,1] range +} + +// RMSD is lower=better, with current models in the range of ~0.02-0.25 Å +// We invert this so that better performance = higher score +function normalize_rmsd(value: number | undefined): number { + if (value === undefined || isNaN(value)) return 0 + + // Fixed reference points for RMSD (in Å) + const excellent = 0 // Perfect performance (atoms in exact correct positions) + const baseline = 0.3 // in Å, a reasonable baseline for 
poor performance given worst performing model at time of writing is AlphaNet-MPTrj at 0.0227 Å + + // Linear interpolation between fixed points with clamping + // Inverse mapping since lower RMSD is better + if (value <= excellent) return 1.0 + if (value >= baseline) return 0.0 + return (baseline - value) / (baseline - excellent) +} + +// kappa_SRME is symmetric relative mean error, with range [0,2] by definition +// Lower values are better (0 is perfect) +function normalize_kappa_srme(value: number | undefined): number { + if (value === undefined || isNaN(value)) return 0 + + // Simple linear normalization from [0,2] to [1,0] + // Clamp at 0 as a safeguard in case a reported SRME exceeds the nominal upper bound of 2 + return Math.max(0, 1 - value / 2) +} + +/** + * Calculate a combined score using normalized metrics weighted by importance factors. + * This uses fixed normalization reference points to ensure score stability when new models are added. + * + * Normalization reference points: + * - F1: Already in [0,1] range, higher is better + * - RMSD: 0.0Å (perfect) to 0.3Å (baseline), lower is better + * - κ_SRME: Range [0,2] linearly mapped to [1,0], lower is better + * + * @param f1 F1 score for discovery + * @param rmsd Root mean square displacement in Å + * @param kappa Symmetric relative mean error for thermal conductivity + * @param config Configuration with weights for each metric + * @returns Combined score between 0-1, or NaN if any weighted metric is missing + */ +export function calculate_combined_score( + f1: number | undefined, + rmsd: number | undefined, + kappa: number | undefined, + config: CombinedMetricConfig, +): number { + // Find weights from config by metric names + const f1_weight = + config.weights.find((w) => w.metric === `F1`)?.value ?? F1_DEFAULT_WEIGHT + const rmsd_weight = + config.weights.find((w) => w.metric === `RMSD`)?.value ?? RMSD_DEFAULT_WEIGHT + const kappa_weight = + config.weights.find((w) => w.metric === `kappa_SRME`)?.value ?? 
KAPPA_DEFAULT_WEIGHT + + // Check if any weighted metric is missing - if so, return NaN + if ( + (f1_weight > 0 && f1 === undefined) || + (rmsd_weight > 0 && rmsd === undefined) || + (kappa_weight > 0 && kappa === undefined) + ) { + return NaN + } + + // Get normalized metric values + const normalized_f1 = normalize_f1(f1) + const normalized_rmsd = normalize_rmsd(rmsd) + const normalized_kappa = normalize_kappa_srme(kappa) + + // Get available weights and metrics + const available_metrics = [] + const available_weights = [] + + // Only include metrics that are available + if (f1 !== undefined) { + available_metrics.push(normalized_f1) + available_weights.push(f1_weight) + } + + if (rmsd !== undefined) { + available_metrics.push(normalized_rmsd) + available_weights.push(rmsd_weight) + } + + if (kappa !== undefined) { + available_metrics.push(normalized_kappa) + available_weights.push(kappa_weight) + } + + // If no metrics are available, return 0 + if (available_metrics.length === 0) return 0 + + // Normalize weights to sum to 1 based on available metrics + const weight_sum = available_weights.reduce((sum, w) => sum + w, 0) + const normalized_weights = + weight_sum > 0 + ? 
available_weights.map((w) => w / weight_sum) + : available_weights.map(() => 1 / available_weights.length) + + // Calculate weighted average + let score = 0 + for (let i = 0; i < available_metrics.length; i++) { + score += available_metrics[i] * normalized_weights[i] + } + + return score +} diff --git a/site/src/lib/types.ts b/site/src/lib/types.ts index 47427b35..b842e2b1 100644 --- a/site/src/lib/types.ts +++ b/site/src/lib/types.ts @@ -111,3 +111,21 @@ export type DiatomicsCurves = { 'homo-nuclear': Record 'hetero-nuclear'?: Record } + +// MetricWeight defines weights for each component of the combined score +export interface MetricWeight { + metric: string // ID of the metric (F1, RMSD, kappa_SRME) + value: number // Weight value between 0-1 + label: string // Display name (can include HTML) + description: string // Description of the metric +} + +export interface CombinedMetricConfig { + weights: MetricWeight[] + name: string + description: string +} + +export type CellVal = string | number | undefined | null +export type RowData = Record +export type TableData = RowData[] diff --git a/site/src/routes/+page.svelte b/site/src/routes/+page.svelte index eea57e86..b78220a3 100644 --- a/site/src/routes/+page.svelte +++ b/site/src/routes/+page.svelte @@ -1,24 +1,30 @@ @@ -56,7 +90,7 @@ + + { + metric_config = { + ...metric_config, + weights: weights.map((w) => ({ ...w })), + } + }} + size={260} + /> + {/snippet} @@ -213,13 +253,56 @@ padding: 0 6pt; border-radius: 4pt; } - div.table-controls { + + /* Caption Radar Container Styles */ + figcaption.caption-radar-container { display: flex; flex-wrap: wrap; - gap: 5pt; - place-content: center; - margin: 3pt auto; + align-items: flex-start; + gap: 1em; + background-color: transparent; } + + .radar-container { + width: fit-content; + flex: 0 0 auto; + max-width: 100%; + background: var(--light-bg); + border-radius: 4px; + padding: 0.2em 0.5em; + box-sizing: border-box; + } + + .radar-header { + display: flex; + 
align-items: center; + gap: 0.5rem; + } + + .metric-name { + font-weight: 600; + margin-right: 0.3em; + } + + .info-icon { + opacity: 0.7; + cursor: help; + } + + .action-button { + background: transparent; + border: 1px solid rgba(255, 255, 255, 0.15); + border-radius: 3px; + cursor: pointer; + padding: 0.15em 0.35em; + font-size: 0.8em; + margin-left: auto; + } + + .action-button:hover { + background: rgba(255, 255, 255, 0.05); + } + figure#metrics-table :global(:is(sub, sup)) { font-size: 0.7em; } diff --git a/site/src/routes/tasks/discovery/+page.svelte b/site/src/routes/tasks/discovery/+page.svelte index 6ae34555..fcd05daf 100644 --- a/site/src/routes/tasks/discovery/+page.svelte +++ b/site/src/routes/tasks/discovery/+page.svelte @@ -1,9 +1,9 @@