From 6d04e9da5c72a4eb286443867bc406fc66c71930 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Tue, 13 Jan 2026 17:02:24 -0800 Subject: [PATCH 01/12] Add famst crate: Fast Approximate Minimum Spanning Tree Implementation of the FAMST algorithm from Almansoori & Telek (2025). Features: - Generic over data type T and distance function - Uses NN-Descent for ANN graph construction - Three-phase approach: ANN graph, component connection, edge refinement - O(dn log n) time complexity, O(dn + kn) space complexity Paper: https://arxiv.org/abs/2507.14261 --- crates/famst/Cargo.toml | 12 + crates/famst/src/lib.rs | 1026 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 1038 insertions(+) create mode 100644 crates/famst/Cargo.toml create mode 100644 crates/famst/src/lib.rs diff --git a/crates/famst/Cargo.toml b/crates/famst/Cargo.toml new file mode 100644 index 0000000..da26573 --- /dev/null +++ b/crates/famst/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "famst" +version = "0.1.0" +edition = "2024" +description = "Fast Approximate Minimum Spanning Tree (FAMST) algorithm" +license = "MIT" + +[dependencies] +rand = "0.8" + +[dev-dependencies] +rand = "0.8" diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs new file mode 100644 index 0000000..6c97bf6 --- /dev/null +++ b/crates/famst/src/lib.rs @@ -0,0 +1,1026 @@ +//! FAMST: Fast Approximate Minimum Spanning Tree +//! +//! Implementation of the FAMST algorithm from: +//! "FAMST: Fast Approximate Minimum Spanning Tree Construction for Large-Scale +//! and High-Dimensional Data" (Almansoori & Telek, 2025) +//! +//! The algorithm uses three phases: +//! 1. ANN graph construction using NN-Descent +//! 2. Component analysis and connection with random edges +//! 3. Iterative edge refinement +//! +//! Generic over data type `T` and distance function. + +use rand::seq::SliceRandom; +use rand::Rng; +use std::collections::{BinaryHeap, HashMap, HashSet}; + +/// An edge in the MST, represented as (node_a, node_b, distance) +#[derive(Debug, Clone)] +pub struct Edge { + pub u: usize, + pub v: usize, + pub distance: f64, +} + +impl Edge { + pub fn new(u: usize, v: usize, distance: f64) -> Self { + Edge { u, v, distance } + } +} + +/// Union-Find (Disjoint Set Union) data structure for Kruskal's algorithm +pub struct UnionFind { + parent: Vec, + rank: Vec, +} + +impl UnionFind { + pub fn new(n: usize) -> Self { + UnionFind { + parent: (0..n).collect(), + rank: vec![0; n], + } + } + + pub fn find(&mut self, x: usize) -> usize { + if self.parent[x] != x { + self.parent[x] = self.find(self.parent[x]); // Path compression + } + self.parent[x] + } + + pub fn union(&mut self, x: usize, y: usize) -> bool { + let px = self.find(x); + let py = self.find(y); + if px == py { + return false; + } + // Union by rank + match self.rank[px].cmp(&self.rank[py]) { + std::cmp::Ordering::Less => self.parent[px] = py, + std::cmp::Ordering::Greater => self.parent[py] = px, + std::cmp::Ordering::Equal => { + self.parent[py] = px; + self.rank[px] += 1; + } + } + true + } +} + +/// Approximate Nearest Neighbors graph representation +/// Contains neighbor indices and distances for each point +pub struct AnnGraph { + /// neighbors[i] contains the indices of k nearest neighbors of point i + pub neighbors: Vec>, + /// distances[i] contains the distances to k nearest neighbors of point i + pub distances: Vec>, +} + +impl AnnGraph { + pub fn new(neighbors: Vec>, distances: Vec>) -> Self { + assert_eq!(neighbors.len(), distances.len()); + AnnGraph { + neighbors, + distances, + } + } + + pub fn n(&self) -> usize { + self.neighbors.len() + } +} + +/// FAMST algorithm configuration +pub struct FamstConfig { + /// Number of nearest neighbors (k in k-NN graph) + pub k: usize, + /// Number of random edges per component pair (λ in the paper) + pub lambda: usize, + /// Maximum refinement iterations (0 for unlimited until convergence) + pub max_iterations: usize, + /// Maximum NN-Descent iterations + pub nn_descent_iterations: usize, + /// Sample rate for NN-Descent (fraction of neighbors to sample) + pub nn_descent_sample_rate: f64, +} + +impl Default for FamstConfig { + fn default() -> Self { + FamstConfig { + k: 20, + lambda: 5, + max_iterations: 100, + nn_descent_iterations: 10, + nn_descent_sample_rate: 0.5, + } + } +} + +/// Result of FAMST algorithm +pub struct FamstResult { + /// MST edges + pub edges: Vec, + /// Total weight of the MST + pub total_weight: f64, +} + +/// Main FAMST algorithm implementation +/// +/// Generic over: +/// - `T`: The data type stored at each point +/// - `D`: Distance function `Fn(&T, &T) -> f64` +/// +/// # Arguments +/// * `data` - Slice of data points +/// * `distance_fn` - Function to compute distance between two points +/// * `config` - Algorithm configuration +/// +/// # Returns +/// The approximate MST as a list of edges +pub fn famst(data: &[T], distance_fn: D, config: &FamstConfig) -> FamstResult +where + D: Fn(&T, &T) -> f64, +{ + famst_with_rng(data, distance_fn, config, &mut rand::thread_rng()) +} + +/// FAMST with custom RNG for reproducibility +pub fn famst_with_rng( + data: &[T], + distance_fn: D, + config: &FamstConfig, + rng: &mut R, +) -> FamstResult +where + D: Fn(&T, &T) -> f64, + R: Rng, +{ + let n = data.len(); + if n == 0 { + return FamstResult { + edges: vec![], + total_weight: 0.0, + }; + } + if n == 1 { + return FamstResult { + edges: vec![], + total_weight: 0.0, + }; + } + + // Phase 1: Build ANN graph using NN-Descent + let ann_graph = nn_descent(data, &distance_fn, config, rng); + + // Phase 2: Build undirected graph and find connected components + let (undirected_graph, components) = find_components(&ann_graph); + + // If only one component, skip inter-component edge logic + println!("components {}", components.len()); + if components.len() <= 1 { + let edges = extract_mst_from_ann(&ann_graph, n); + let total_weight = edges.iter().map(|e| e.distance).sum(); + return FamstResult { + edges, + total_weight, + }; + } + + // Phase 2 continued: Add random edges between components + let (mut inter_edges, edge_components) = + add_random_edges(data, &components, config.lambda, &distance_fn, rng); + + // Phase 3: Iterative edge refinement + let mut iterations = 0; + loop { + let (refined_edges, changes) = refine_edges( + data, + &undirected_graph, + &components, + &inter_edges, + &edge_components, + &distance_fn, + ); + inter_edges = refined_edges; + + if changes == 0 { + break; + } + + iterations += 1; + if config.max_iterations > 0 && iterations >= config.max_iterations { + break; + } + } + + // Phase 4: Extract MST using Kruskal's algorithm + let edges = extract_mst(&ann_graph, &inter_edges, n); + let total_weight = edges.iter().map(|e| e.distance).sum(); + + FamstResult { + edges, + total_weight, + } +} + +/// A neighbor entry in the k-NN heap (max-heap by distance for easy replacement of farthest) +#[derive(Clone, Copy)] +struct NeighborEntry { + index: usize, + distance: f64, +} + +impl PartialEq for NeighborEntry { + fn eq(&self, other: &Self) -> bool { + self.distance == other.distance + } +} + +impl Eq for NeighborEntry {} + +impl PartialOrd for NeighborEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for NeighborEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Max-heap: larger distances have higher priority + self.distance + .partial_cmp(&other.distance) + .unwrap_or(std::cmp::Ordering::Equal) + } +} + +/// NN-Descent algorithm for approximate k-NN graph construction +/// +/// Based on: "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures" +/// by Wei Dong, Charikar Moses, and Kai Li (2011) +fn nn_descent(data: &[T], distance_fn: &D, config: &FamstConfig, rng: &mut R) -> AnnGraph +where + D: Fn(&T, &T) -> f64, + R: Rng, +{ + let n = data.len(); + let k = config.k.min(n - 1); + + if k == 0 || n <= 1 { + return AnnGraph::new(vec![vec![]; n], vec![vec![]; n]); + } + + // Initialize with random neighbors using max-heap for each point + let mut heaps: Vec> = Vec::with_capacity(n); + let mut neighbor_sets: Vec> = vec![HashSet::with_capacity(k); n]; + + for i in 0..n { + let mut heap = BinaryHeap::with_capacity(k); + let mut indices: Vec = (0..n).filter(|&j| j != i).collect(); + indices.shuffle(rng); + + for &j in indices.iter().take(k) { + let d = distance_fn(&data[i], &data[j]); + heap.push(NeighborEntry { + index: j, + distance: d, + }); + neighbor_sets[i].insert(j); + } + heaps.push(heap); + } + + // Build reverse neighbor lists (who has me as a neighbor) + let build_reverse = |neighbor_sets: &[HashSet]| -> Vec> { + let mut reverse: Vec> = vec![HashSet::new(); n]; + for (i, neighbors) in neighbor_sets.iter().enumerate() { + for &j in neighbors { + reverse[j].insert(i); + } + } + reverse + }; + + // NN-Descent iterations + for _ in 0..config.nn_descent_iterations { + let mut updates = 0; + let reverse_neighbors = build_reverse(&neighbor_sets); + + // For each point, explore neighbors of neighbors + for i in 0..n { + // Collect candidates: neighbors and reverse neighbors + let mut candidates: Vec = Vec::new(); + + // Sample from forward neighbors + let forward: Vec = neighbor_sets[i].iter().copied().collect(); + let sample_size = + ((forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); + let mut sampled_forward = forward.clone(); + sampled_forward.shuffle(rng); + sampled_forward.truncate(sample_size); + + // Sample from reverse neighbors + let reverse: Vec = reverse_neighbors[i].iter().copied().collect(); + let sample_size = + ((reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); + let mut sampled_reverse = reverse.clone(); + sampled_reverse.shuffle(rng); + sampled_reverse.truncate(sample_size); + + // Neighbors of neighbors + for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) { + for &nn in &neighbor_sets[neighbor] { + if nn != i && !neighbor_sets[i].contains(&nn) { + candidates.push(nn); + } + } + // Also check reverse neighbors of neighbors + for &rn in &reverse_neighbors[neighbor] { + if rn != i && !neighbor_sets[i].contains(&rn) { + candidates.push(rn); + } + } + } + + // Deduplicate candidates + candidates.sort_unstable(); + candidates.dedup(); + + // Try to improve neighbors + for c in candidates { + let d = distance_fn(&data[i], &data[c]); + + // Check if this is better than the worst current neighbor + if let Some(worst) = heaps[i].peek() { + if d < worst.distance { + // Remove worst and add new neighbor + let removed = heaps[i].pop().unwrap(); + neighbor_sets[i].remove(&removed.index); + + heaps[i].push(NeighborEntry { + index: c, + distance: d, + }); + neighbor_sets[i].insert(c); + updates += 1; + } + } + } + } + + // Early termination if no updates + if updates == 0 { + break; + } + } + + // Convert heaps to sorted neighbor lists + let mut neighbors = vec![Vec::with_capacity(k); n]; + let mut distances = vec![Vec::with_capacity(k); n]; + + for (i, heap) in heaps.into_iter().enumerate() { + let mut entries: Vec = heap.into_vec(); + entries.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + + for entry in entries { + neighbors[i].push(entry.index); + distances[i].push(entry.distance); + } + } + + AnnGraph::new(neighbors, distances) +} + +/// Find connected components in the ANN graph using DFS +/// Returns the undirected graph adjacency list and component assignments +fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { + let n = ann_graph.n(); + + // Build undirected graph from directed ANN graph + let mut graph: Vec> = vec![HashSet::new(); n]; + for (i, neighbors) in ann_graph.neighbors.iter().enumerate() { + for &j in neighbors { + graph[i].insert(j); + graph[j].insert(i); + } + } + + // DFS to find components + let mut visited = vec![false; n]; + let mut components: Vec> = Vec::new(); + + for start in 0..n { + if visited[start] { + continue; + } + + let mut component = Vec::new(); + let mut stack = vec![start]; + + while let Some(u) = stack.pop() { + if visited[u] { + continue; + } + visited[u] = true; + component.push(u); + + for &v in &graph[u] { + if !visited[v] { + stack.push(v); + } + } + } + + components.push(component); + } + + (graph, components) +} + +/// Add random edges between components (Algorithm 3 in the paper) +fn add_random_edges( + data: &[T], + components: &[Vec], + lambda: usize, + distance_fn: &D, + rng: &mut R, +) -> (Vec, Vec<(usize, usize)>) +where + D: Fn(&T, &T) -> f64, + R: Rng, +{ + let t = components.len(); + let mut edges = Vec::new(); + let mut edge_components = Vec::new(); + + let lambda_sq = lambda * lambda; + + for i in 0..t { + for j in (i + 1)..t { + let mut candidates: Vec = Vec::with_capacity(lambda_sq); + + // Generate λ² candidate edges + for _ in 0..lambda_sq { + let u = *components[i].choose(rng).unwrap(); + let v = *components[j].choose(rng).unwrap(); + let d = distance_fn(&data[u], &data[v]); + candidates.push(Edge::new(u, v, d)); + } + + // Sort by distance and take top λ + candidates.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + + for edge in candidates.into_iter().take(lambda) { + edges.push(edge); + edge_components.push((i, j)); + } + } + } + + (edges, edge_components) +} + +/// Refine inter-component edges (Algorithm 4 in the paper) +fn refine_edges( + data: &[T], + undirected_graph: &[HashSet], + components: &[Vec], + edges: &[Edge], + edge_components: &[(usize, usize)], + distance_fn: &D, +) -> (Vec, usize) +where + D: Fn(&T, &T) -> f64, +{ + // Build component membership lookup + let mut node_to_component: HashMap = HashMap::new(); + for (comp_idx, component) in components.iter().enumerate() { + for &node in component { + node_to_component.insert(node, comp_idx); + } + } + + // Build component node sets for quick lookup + let component_sets: Vec> = components + .iter() + .map(|c| c.iter().copied().collect()) + .collect(); + + let mut refined_edges = Vec::with_capacity(edges.len()); + let mut changes = 0; + + for (edge, &(ci, cj)) in edges.iter().zip(edge_components.iter()) { + let mut best_u = edge.u; + let mut best_v = edge.v; + let mut best_d = edge.distance; + + // Get neighbors of u that are in component ci + let neighbors_u: Vec = undirected_graph[edge.u] + .iter() + .filter(|&&n| component_sets[ci].contains(&n)) + .copied() + .collect(); + + // Try to find better u from neighbors + for u_prime in neighbors_u { + if u_prime == edge.v { + continue; + } + let d_prime = distance_fn(&data[u_prime], &data[best_v]); + if d_prime < best_d { + best_u = u_prime; + best_d = d_prime; + } + } + + // Get neighbors of v that are in component cj + let neighbors_v: Vec = undirected_graph[edge.v] + .iter() + .filter(|&&n| component_sets[cj].contains(&n)) + .copied() + .collect(); + + // Try to find better v from neighbors (using updated best_u) + for v_prime in neighbors_v { + if v_prime == edge.u { + continue; + } + let d_prime = distance_fn(&data[best_u], &data[v_prime]); + if d_prime < best_d { + best_v = v_prime; + best_d = d_prime; + } + } + + if best_u != edge.u || best_v != edge.v { + changes += 1; + } + + refined_edges.push(Edge::new(best_u, best_v, best_d)); + } + + (refined_edges, changes) +} + +/// Extract MST using Kruskal's algorithm on the connected ANN graph +fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec { + // Collect all edges from ANN graph + let mut all_edges: Vec = Vec::new(); + + for (i, (neighbors, distances)) in ann_graph + .neighbors + .iter() + .zip(ann_graph.distances.iter()) + .enumerate() + { + for (&j, &d) in neighbors.iter().zip(distances.iter()) { + // Only add edge once (i < j) + if i < j { + all_edges.push(Edge::new(i, j, d)); + } + } + } + + // Add edges where j < i (reverse direction in ANN graph) + for (i, (neighbors, distances)) in ann_graph + .neighbors + .iter() + .zip(ann_graph.distances.iter()) + .enumerate() + { + for (&j, &d) in neighbors.iter().zip(distances.iter()) { + if j < i { + // Check if this edge isn't already added + all_edges.push(Edge::new(j, i, d)); + } + } + } + + // Add inter-component edges + for edge in inter_edges { + all_edges.push(edge.clone()); + } + + // Deduplicate edges + let mut edge_set: HashMap<(usize, usize), f64> = HashMap::new(); + for edge in all_edges { + let key = if edge.u < edge.v { + (edge.u, edge.v) + } else { + (edge.v, edge.u) + }; + edge_set + .entry(key) + .and_modify(|d| { + if edge.distance < *d { + *d = edge.distance + } + }) + .or_insert(edge.distance); + } + + let mut edges: Vec = edge_set + .into_iter() + .map(|((u, v), d)| Edge::new(u, v, d)) + .collect(); + + // Sort edges by weight + edges.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + + // Kruskal's algorithm + let mut uf = UnionFind::new(n); + let mut mst_edges = Vec::with_capacity(n - 1); + + for edge in edges { + if uf.union(edge.u, edge.v) { + mst_edges.push(edge); + if mst_edges.len() == n - 1 { + break; + } + } + } + + mst_edges +} + +/// Extract MST when graph is already connected (single component) +fn extract_mst_from_ann(ann_graph: &AnnGraph, n: usize) -> Vec { + extract_mst(ann_graph, &[], n) +} + +/// Euclidean distance for slices of f64 +pub fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { + a.iter() + .zip(b.iter()) + .map(|(x, y)| (x - y).powi(2)) + .sum::() + .sqrt() +} + +/// Manhattan distance for slices of f64 +pub fn manhattan_distance(a: &[f64], b: &[f64]) -> f64 { + a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum() +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::SeedableRng; + + #[test] + fn test_union_find() { + let mut uf = UnionFind::new(5); + assert!(uf.union(0, 1)); + assert!(uf.union(2, 3)); + assert!(!uf.union(0, 1)); // Already same set + assert!(uf.union(1, 2)); + assert_eq!(uf.find(0), uf.find(3)); + } + + #[test] + fn test_simple_mst() { + // Simple 2D points forming a triangle + let points: Vec> = vec![ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![0.5, 0.866], // Equilateral triangle + ]; + + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let config = FamstConfig { + k: 2, + ..Default::default() + }; + + let mut rng = StdRng::seed_from_u64(42); + let result = famst_with_rng(&points, distance, &config, &mut rng); + + assert_eq!(result.edges.len(), 2); // MST has n-1 edges + } + + #[test] + fn test_line_points() { + // Points on a line + let points: Vec> = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]]; + + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let config = FamstConfig { + k: 2, + ..Default::default() + }; + + let mut rng = StdRng::seed_from_u64(42); + let result = famst_with_rng(&points, distance, &config, &mut rng); + + assert_eq!(result.edges.len(), 4); + // Total weight should be 4.0 (1+1+1+1) + assert!((result.total_weight - 4.0).abs() < 1e-10); + } + + #[test] + fn test_disconnected_components() { + // Two clusters far apart + let points: Vec> = vec![ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![0.5, 0.5], + vec![100.0, 100.0], + vec![101.0, 100.0], + vec![100.5, 100.5], + ]; + + // k=1 will likely create disconnected components + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let config = FamstConfig { + k: 1, + lambda: 3, + max_iterations: 10, + ..Default::default() + }; + + let mut rng = StdRng::seed_from_u64(42); + let result = famst_with_rng(&points, distance, &config, &mut rng); + + assert_eq!(result.edges.len(), 5); // MST has n-1 edges + } + + #[test] + fn test_custom_distance() { + // Test with Manhattan distance + let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]]; + + let distance = |a: &Vec, b: &Vec| manhattan_distance(a, b); + let config = FamstConfig { + k: 2, + ..Default::default() + }; + + let mut rng = StdRng::seed_from_u64(42); + let result = famst_with_rng(&points, distance, &config, &mut rng); + + assert_eq!(result.edges.len(), 2); + // Manhattan distance from (0,0) to (1,1) is 2, and (1,1) to (2,2) is 2 + assert!((result.total_weight - 4.0).abs() < 1e-10); + } + + #[test] + fn test_generic_data_type() { + // Test with a custom struct + #[derive(Clone)] + struct Point3D { + x: f64, + y: f64, + z: f64, + } + + fn point_distance(a: &Point3D, b: &Point3D) -> f64 { + ((a.x - b.x).powi(2) + (a.y - b.y).powi(2) + (a.z - b.z).powi(2)).sqrt() + } + + let points = vec![ + Point3D { + x: 0.0, + y: 0.0, + z: 0.0, + }, + Point3D { + x: 1.0, + y: 0.0, + z: 0.0, + }, + Point3D { + x: 0.0, + y: 1.0, + z: 0.0, + }, + Point3D { + x: 0.0, + y: 0.0, + z: 1.0, + }, + ]; + + let config = FamstConfig { + k: 3, + ..Default::default() + }; + + let mut rng = StdRng::seed_from_u64(42); + let result = famst_with_rng(&points, point_distance, &config, &mut rng); + + assert_eq!(result.edges.len(), 3); + } + + #[test] + fn test_multiple_clusters() { + // Create 5 well-separated clusters to force multiple components with small k + // Each cluster is a tight group of points, clusters are far apart + use rand::distributions::{Distribution, Uniform}; + + let mut rng = StdRng::seed_from_u64(77777); + let noise = Uniform::new(-0.5, 0.5); + + let cluster_centers = vec![ + vec![0.0, 0.0], + vec![100.0, 0.0], + vec![0.0, 100.0], + vec![100.0, 100.0], + vec![50.0, 50.0], + ]; + + let points_per_cluster = 20; + let mut points: Vec> = Vec::new(); + + for center in &cluster_centers { + for _ in 0..points_per_cluster { + let point = vec![ + center[0] + noise.sample(&mut rng), + center[1] + noise.sample(&mut rng), + ]; + points.push(point); + } + } + + let n = points.len(); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + + // Use small k to create disconnected components + // With k=3 and 20 points per cluster spread over 5 clusters, + // each point's 3 nearest neighbors will be in its own cluster + let config = FamstConfig { + k: 3, + lambda: 5, + max_iterations: 50, + nn_descent_iterations: 20, + nn_descent_sample_rate: 1.0, // Full sampling for small dataset + }; + + let mut famst_rng = StdRng::seed_from_u64(88888); + let result = famst_with_rng(&points, distance, &config, &mut famst_rng); + + // Should produce a valid MST with n-1 edges + assert_eq!(result.edges.len(), n - 1, "MST should have n-1 edges"); + + // Verify connectivity: all nodes should be reachable + let mut uf = UnionFind::new(n); + for edge in &result.edges { + uf.union(edge.u, edge.v); + } + // Check all nodes are in the same component + let root = uf.find(0); + for i in 1..n { + assert_eq!(uf.find(i), root, "All nodes should be connected in the MST"); + } + + // Compare with exact MST + let exact_weight = exact_mst_weight(&points, distance); + let error_ratio = (result.total_weight - exact_weight) / exact_weight; + + println!( + "Multi-cluster test: Exact MST weight: {:.4}, FAMST weight: {:.4}, error: {:.2}%", + exact_weight, + result.total_weight, + error_ratio * 100.0 + ); + + // Should be reasonably close (within 15% given the challenging setup) + assert!( + error_ratio < 0.15, + "FAMST error should be < 15%, got {:.2}%", + error_ratio * 100.0 + ); + } + + /// Compute exact MST using Kruskal's algorithm on complete graph + fn exact_mst_weight(data: &[T], distance_fn: D) -> f64 + where + D: Fn(&T, &T) -> f64, + { + let n = data.len(); + if n <= 1 { + return 0.0; + } + + // Build all edges + let mut edges: Vec<(usize, usize, f64)> = Vec::with_capacity(n * (n - 1) / 2); + for i in 0..n { + for j in (i + 1)..n { + let d = distance_fn(&data[i], &data[j]); + edges.push((i, j, d)); + } + } + + // Sort by weight + edges.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap()); + + // Kruskal's algorithm + let mut uf = UnionFind::new(n); + let mut total_weight = 0.0; + let mut edge_count = 0; + + for (u, v, w) in edges { + if uf.union(u, v) { + total_weight += w; + edge_count += 1; + if edge_count == n - 1 { + break; + } + } + } + + total_weight + } + + #[test] + #[ignore] // Run with: cargo test large_scale -- --ignored --nocapture + fn test_large_scale_vs_exact() { + use rand::distributions::{Distribution, Uniform}; + + const N: usize = 1_000_000; + const DIM: usize = 10; + + println!("Generating {} random {}-dimensional points...", N, DIM); + let mut rng = StdRng::seed_from_u64(12345); + let dist = Uniform::new(0.0, 1000.0); + + let points: Vec> = (0..N) + .map(|_| (0..DIM).map(|_| dist.sample(&mut rng)).collect()) + .collect(); + + println!("Running FAMST with NN-Descent..."); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let config = FamstConfig { + k: 20, + lambda: 5, + max_iterations: 100, + nn_descent_iterations: 10, + nn_descent_sample_rate: 0.5, + }; + let mut famst_rng = StdRng::seed_from_u64(54321); + let start = std::time::Instant::now(); + let result = famst_with_rng(&points, distance, &config, &mut famst_rng); + let famst_time = start.elapsed(); + + println!("FAMST completed in {:?}", famst_time); + println!("FAMST MST weight: {:.4}", result.total_weight); + println!("FAMST MST edges: {}", result.edges.len()); + + assert_eq!(result.edges.len(), N - 1, "MST should have n-1 edges"); + } + + #[test] + fn test_medium_scale_vs_exact() { + use rand::distributions::{Distribution, Uniform}; + + const N: usize = 5000; + const DIM: usize = 5; + + let mut rng = StdRng::seed_from_u64(99999); + let dist = Uniform::new(0.0, 100.0); + + let points: Vec> = (0..N) + .map(|_| (0..DIM).map(|_| dist.sample(&mut rng)).collect()) + .collect(); + + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + + // Compute exact MST + let exact_weight = exact_mst_weight(&points, distance); + + // Compute approximate MST with FAMST + let config = FamstConfig { + k: 15, + lambda: 5, + max_iterations: 100, + nn_descent_iterations: 15, + nn_descent_sample_rate: 0.5, + }; + let mut famst_rng = StdRng::seed_from_u64(11111); + let result = famst_with_rng(&points, distance, &config, &mut famst_rng); + + assert_eq!(result.edges.len(), N - 1); + + // FAMST should produce a weight >= exact (it's an approximation) + // and should be reasonably close (within a few percent for good k) + let error_ratio = (result.total_weight - exact_weight) / exact_weight; + println!( + "Exact MST weight: {:.4}, FAMST weight: {:.4}, error: {:.2}%", + exact_weight, + result.total_weight, + error_ratio * 100.0 + ); + + // The approximation should be within 10% for this setup with NN-Descent + assert!( + error_ratio >= 0.0, + "FAMST weight should be >= exact MST weight" + ); + assert!( + error_ratio < 0.10, + "FAMST error should be < 10%, got {:.2}%", + error_ratio * 100.0 + ); + } +} From c09952fffb25c7063a0774ba62117e2f73bdcf5e Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 14 Jan 2026 14:52:08 -0800 Subject: [PATCH 02/12] =?UTF-8?q?Fix=20O(n=C2=B2)=20initialization=20bug?= =?UTF-8?q?=20in=20NN-Descent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use Floyd's algorithm to sample k random neighbors in O(k) time instead of allocating and shuffling all n indices per point. --- crates/famst/src/lib.rs | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 6c97bf6..296e8dd 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -277,16 +277,33 @@ where for i in 0..n { let mut heap = BinaryHeap::with_capacity(k); - let mut indices: Vec = (0..n).filter(|&j| j != i).collect(); - indices.shuffle(rng); - - for &j in indices.iter().take(k) { - let d = distance_fn(&data[i], &data[j]); - heap.push(NeighborEntry { - index: j, - distance: d, - }); - neighbor_sets[i].insert(j); + + // Sample k random neighbors using Floyd's algorithm - guaranteed O(k) + // https://fermatslibrary.com/s/a-sample-of-brilliance + // This selects k distinct elements from 0..n, excluding i + let effective_n = n - 1; // exclude self + let range_start = effective_n.saturating_sub(k); + for t in range_start..effective_n { + let j = rng.gen_range(0..=t); + // Map j to actual index, skipping i + let actual_j = if j >= i { j + 1 } else { j }; + + if !neighbor_sets[i].insert(actual_j) { + // j was already selected, so add t instead + let actual_t = if t >= i { t + 1 } else { t }; + neighbor_sets[i].insert(actual_t); + let d = distance_fn(&data[i], &data[actual_t]); + heap.push(NeighborEntry { + index: actual_t, + distance: d, + }); + } else { + let d = distance_fn(&data[i], &data[actual_j]); + heap.push(NeighborEntry { + index: actual_j, + distance: d, + }); + } } heaps.push(heap); } From 4ad3eb2033c1e71476960097498cd1a5ad2ae521 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 14 Jan 2026 15:34:00 -0800 Subject: [PATCH 03/12] Unify duplicate loops in extract_mst --- crates/famst/src/lib.rs | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 296e8dd..71fb291 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -592,25 +592,8 @@ fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec Date: Wed, 14 Jan 2026 15:36:30 -0800 Subject: [PATCH 04/12] Remove unnecessary edge deduplication in extract_mst Kruskal's algorithm naturally skips duplicate edges via union-find. --- crates/famst/src/lib.rs | 34 ++++------------------------------ 1 file changed, 4 insertions(+), 30 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 71fb291..f46349a 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -583,7 +583,7 @@ where /// Extract MST using Kruskal's algorithm on the connected ANN graph fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec { // Collect all edges from ANN graph - let mut all_edges: Vec = Vec::new(); + let mut edges: Vec = Vec::new(); for (i, (neighbors, distances)) in ann_graph .neighbors @@ -592,43 +592,17 @@ fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec = HashMap::new(); - for edge in all_edges { - let key = if edge.u < edge.v { - (edge.u, edge.v) - } else { - (edge.v, edge.u) - }; - edge_set - .entry(key) - .and_modify(|d| { - if edge.distance < *d { - *d = edge.distance - } - }) - .or_insert(edge.distance); - } - - let mut edges: Vec = edge_set - .into_iter() - .map(|((u, v), d)| Edge::new(u, v, d)) - .collect(); + edges.extend(inter_edges.iter().cloned()); // Sort edges by weight edges.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); - // Kruskal's algorithm + // Kruskal's algorithm (naturally handles duplicate edges) let mut uf = UnionFind::new(n); let mut mst_edges = Vec::with_capacity(n - 1); From c14f0cb9a8072506c3b7303bbe74680b3d10f374 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 14 Jan 2026 15:38:46 -0800 Subject: [PATCH 05/12] Unify n==0 and n==1 cases, remove debug println --- crates/famst/src/lib.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index f46349a..8d06bbe 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -158,13 +158,7 @@ where R: Rng, { let n = data.len(); - if n == 0 { - return FamstResult { - edges: vec![], - total_weight: 0.0, - }; - } - if n == 1 { + if n <= 1 { return FamstResult { edges: vec![], total_weight: 0.0, @@ -178,7 +172,6 @@ where let (undirected_graph, components) = find_components(&ann_graph); // If only one component, skip inter-component edge logic - println!("components {}", components.len()); if components.len() <= 1 { let edges = extract_mst_from_ann(&ann_graph, n); let total_weight = edges.iter().map(|e| e.distance).sum(); From 4ccefd923a4cb45f69ddff9b1d68051b4f959e8c Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 14 Jan 2026 15:39:52 -0800 Subject: [PATCH 06/12] Add tests for empty and single-point inputs --- crates/famst/src/lib.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 8d06bbe..b5a0209 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -636,6 +636,24 @@ mod tests { use rand::rngs::StdRng; use rand::SeedableRng; + #[test] + fn test_empty_input() { + let points: Vec> = vec![]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let result = famst(&points, distance, &FamstConfig::default()); + assert_eq!(result.edges.len(), 0); + assert_eq!(result.total_weight, 0.0); + } + + #[test] + fn test_single_point() { + let points: Vec> = vec![vec![1.0, 2.0, 3.0]]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let result = famst(&points, distance, &FamstConfig::default()); + assert_eq!(result.edges.len(), 0); + assert_eq!(result.total_weight, 0.0); + } + #[test] fn test_union_find() { let mut uf = UnionFind::new(5); From 60d85f226ba97fbcff00eef02748b25d709a2c82 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 14 Jan 2026 15:41:35 -0800 Subject: [PATCH 07/12] Add test for k >= n case --- crates/famst/src/lib.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index b5a0209..932aa4d 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -654,6 +654,20 @@ mod tests { assert_eq!(result.total_weight, 0.0); } + #[test] + fn test_k_greater_than_n() { + // 3 points but k=20 (default), so k >= n + let points: Vec> = vec![ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![0.0, 1.0], + ]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let config = FamstConfig::default(); // k=20 > n=3 + let result = famst(&points, distance, &config); + assert_eq!(result.edges.len(), 2); // MST has n-1 edges + } + #[test] fn test_union_find() { let mut uf = UnionFind::new(5); From e37f359800641ecbc56552d8e0392d1caf6c41a9 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Sat, 17 Jan 2026 11:06:47 -0600 Subject: [PATCH 08/12] Manual tweaks. --- crates/famst/src/lib.rs | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 932aa4d..b147a04 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -146,7 +146,7 @@ where famst_with_rng(data, distance_fn, config, &mut rand::thread_rng()) } -/// FAMST with custom RNG for reproducibility +/// FAMST with custom RNG. (We use a seeded RNG in tests for reproducibility.) pub fn famst_with_rng( data: &[T], distance_fn: D, @@ -616,26 +616,26 @@ fn extract_mst_from_ann(ann_graph: &AnnGraph, n: usize) -> Vec { extract_mst(ann_graph, &[], n) } -/// Euclidean distance for slices of f64 -pub fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { - a.iter() - .zip(b.iter()) - .map(|(x, y)| (x - y).powi(2)) - .sum::() - .sqrt() -} - -/// Manhattan distance for slices of f64 -pub fn manhattan_distance(a: &[f64], b: &[f64]) -> f64 { - a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum() -} - #[cfg(test)] mod tests { use super::*; use rand::rngs::StdRng; use rand::SeedableRng; + /// Manhattan distance for slices of f64 + pub fn manhattan_distance(a: &[f64], b: &[f64]) -> f64 { + a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum() + } + + /// Euclidean distance for slices of f64 + pub fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { + a.iter() + .zip(b.iter()) + .map(|(x, y)| (x - y).powi(2)) + .sum::() + .sqrt() + } + #[test] fn test_empty_input() { let points: Vec> = vec![]; @@ -657,11 +657,7 @@ mod tests { #[test] fn test_k_greater_than_n() { // 3 points but k=20 (default), so k >= n - let points: Vec> = vec![ - vec![0.0, 0.0], - vec![1.0, 0.0], - vec![0.0, 1.0], - ]; + let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]]; let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig::default(); // k=20 > n=3 let result = famst(&points, distance, &config); From b3a257b61d5adbe602af1bf40da64a922436ea3a Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Sat, 17 Jan 2026 11:32:36 -0600 Subject: [PATCH 09/12] Replace HashSets with sorted Vecs for memory efficiency This reduces memory usage by ~5x for large graphs: - neighbor_lists in nn_descent: HashSet -> sorted Vec - reverse_neighbors in nn_descent: HashSet -> Vec (already sorted by construction) - graph in find_components: HashSet -> sorted Vec - node_to_component in refine_edges: HashMap -> Vec (O(1) indexed lookup) - Removed component_sets HashSets entirely For n=1 billion, k=20, this saves ~2.5 TB of memory. --- crates/famst/src/lib.rs | 142 +++++++++++++++++++++------------------- 1 file changed, 76 insertions(+), 66 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index b147a04..16f0f65 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -13,7 +13,7 @@ use rand::seq::SliceRandom; use rand::Rng; -use std::collections::{BinaryHeap, HashMap, HashSet}; +use std::collections::BinaryHeap; /// An edge in the MST, represented as (node_a, node_b, distance) #[derive(Debug, Clone)] @@ -264,9 +264,33 @@ where return AnnGraph::new(vec![vec![]; n], vec![vec![]; n]); } + // Helper: check if sorted vec contains value + fn sorted_contains(v: &[usize], x: usize) -> bool { + v.binary_search(&x).is_ok() + } + + // Helper: insert into sorted vec, returns true if inserted (was not present) + fn sorted_insert(v: &mut Vec, x: usize) -> bool { + match v.binary_search(&x) { + Ok(_) => false, + Err(pos) => { + v.insert(pos, x); + true + } + } + } + + // Helper: remove from sorted vec + fn sorted_remove(v: &mut Vec, x: usize) { + if let Ok(pos) = v.binary_search(&x) { + v.remove(pos); + } + } + // Initialize with random neighbors using max-heap for each point + // neighbor_lists[i] is kept sorted by index for O(log k) membership tests let mut heaps: Vec> = Vec::with_capacity(n); - let mut neighbor_sets: Vec> = vec![HashSet::with_capacity(k); n]; + let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; for i in 0..n { let mut heap = BinaryHeap::with_capacity(k); @@ -281,10 +305,10 @@ where // Map j to actual index, skipping i let actual_j = if j >= i { j + 1 } else { j }; - if !neighbor_sets[i].insert(actual_j) { + if !sorted_insert(&mut neighbor_lists[i], actual_j) { // j was already selected, so add t instead let actual_t = if t >= i { t + 1 } else { t }; - neighbor_sets[i].insert(actual_t); + sorted_insert(&mut neighbor_lists[i], actual_t); let d = distance_fn(&data[i], &data[actual_t]); heap.push(NeighborEntry { index: actual_t, @@ -302,20 +326,22 @@ where } // Build reverse neighbor lists (who has me as a neighbor) - let build_reverse = |neighbor_sets: &[HashSet]| -> Vec> { - let mut reverse: Vec> = vec![HashSet::new(); n]; - for (i, neighbors) in neighbor_sets.iter().enumerate() { + // Returns sorted vecs for each point + let build_reverse = |neighbor_lists: &[Vec]| -> Vec> { + let mut reverse: Vec> = vec![Vec::new(); n]; + for (i, neighbors) in neighbor_lists.iter().enumerate() { for &j in neighbors { - reverse[j].insert(i); + reverse[j].push(i); } } + // Sort each reverse list (they're built in order of i, so already sorted) reverse }; // NN-Descent iterations for _ in 0..config.nn_descent_iterations { let mut updates = 0; - let reverse_neighbors = build_reverse(&neighbor_sets); + let reverse_neighbors = build_reverse(&neighbor_lists); // For each point, explore neighbors of neighbors for i in 0..n { @@ -323,31 +349,29 @@ where let mut candidates: Vec = Vec::new(); // Sample from forward neighbors - let forward: Vec = neighbor_sets[i].iter().copied().collect(); + let mut sampled_forward = neighbor_lists[i].clone(); let sample_size = - ((forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); - let mut sampled_forward = forward.clone(); + ((sampled_forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); sampled_forward.shuffle(rng); sampled_forward.truncate(sample_size); // Sample from reverse neighbors - let reverse: Vec = reverse_neighbors[i].iter().copied().collect(); + let mut sampled_reverse = reverse_neighbors[i].clone(); let sample_size = - ((reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); - let mut sampled_reverse = reverse.clone(); + ((sampled_reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); sampled_reverse.shuffle(rng); sampled_reverse.truncate(sample_size); // Neighbors of neighbors for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) { - for &nn in &neighbor_sets[neighbor] { - if nn != i && !neighbor_sets[i].contains(&nn) { + for &nn in &neighbor_lists[neighbor] { + if nn != i && !sorted_contains(&neighbor_lists[i], nn) { candidates.push(nn); } } // Also check reverse neighbors of neighbors for &rn in &reverse_neighbors[neighbor] { - if rn != i && !neighbor_sets[i].contains(&rn) { + if rn != i && !sorted_contains(&neighbor_lists[i], rn) { candidates.push(rn); } } @@ -366,13 +390,13 @@ where if d < worst.distance { // Remove worst and add new neighbor let removed = heaps[i].pop().unwrap(); - neighbor_sets[i].remove(&removed.index); + sorted_remove(&mut neighbor_lists[i], removed.index); heaps[i].push(NeighborEntry { index: c, distance: d, }); - neighbor_sets[i].insert(c); + sorted_insert(&mut neighbor_lists[i], c); updates += 1; } } @@ -403,18 +427,23 @@ where } /// Find connected components in the ANN graph using DFS -/// Returns the undirected graph adjacency list and component assignments -fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { +/// Returns the undirected graph adjacency list (sorted vecs) and component assignments +fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { let n = ann_graph.n(); - // Build undirected graph from directed ANN graph - let mut graph: Vec> = vec![HashSet::new(); n]; + // Build undirected graph from directed ANN graph using sorted vecs + let mut graph: Vec> = vec![Vec::new(); n]; for (i, neighbors) in ann_graph.neighbors.iter().enumerate() { for &j in neighbors { - graph[i].insert(j); - graph[j].insert(i); + graph[i].push(j); + graph[j].push(i); } } + // Sort and deduplicate each adjacency list + for adj in &mut graph { + adj.sort_unstable(); + adj.dedup(); + } // DFS to find components let mut visited = vec![false; n]; @@ -494,7 +523,7 @@ where /// Refine inter-component edges (Algorithm 4 in the paper) fn refine_edges( data: &[T], - undirected_graph: &[HashSet], + undirected_graph: &[Vec], components: &[Vec], edges: &[Edge], edge_components: &[(usize, usize)], @@ -503,20 +532,16 @@ fn refine_edges( where D: Fn(&T, &T) -> f64, { - // Build component membership lookup - let mut node_to_component: HashMap = HashMap::new(); + let n = data.len(); + + // Build component membership lookup (simple vec, O(1) lookup) + let mut node_to_component: Vec = vec![0; n]; for (comp_idx, component) in components.iter().enumerate() { for &node in component { - node_to_component.insert(node, comp_idx); + node_to_component[node] = comp_idx; } } - // Build component node sets for quick lookup - let component_sets: Vec> = components - .iter() - .map(|c| c.iter().copied().collect()) - .collect(); - let mut refined_edges = Vec::with_capacity(edges.len()); let mut changes = 0; @@ -526,40 +551,25 @@ where let mut best_d = edge.distance; // Get neighbors of u that are in component ci - let neighbors_u: Vec = undirected_graph[edge.u] - .iter() - .filter(|&&n| component_sets[ci].contains(&n)) - .copied() - .collect(); - - // Try to find better u from neighbors - for u_prime in neighbors_u { - if u_prime == edge.v { - continue; - } - let d_prime = distance_fn(&data[u_prime], &data[best_v]); - if d_prime < best_d { - best_u = u_prime; - best_d = d_prime; + // (undirected_graph is sorted, so we can iterate directly) + for &u_prime in &undirected_graph[edge.u] { + if node_to_component[u_prime] == ci && u_prime != edge.v { + let d_prime = distance_fn(&data[u_prime], &data[best_v]); + if d_prime < best_d { + best_u = u_prime; + best_d = d_prime; + } } } // Get neighbors of v that are in component cj - let neighbors_v: Vec = undirected_graph[edge.v] - .iter() - .filter(|&&n| component_sets[cj].contains(&n)) - .copied() - .collect(); - - // Try to find better v from neighbors (using updated best_u) - for v_prime in neighbors_v { - if v_prime == edge.u { - continue; - } - let d_prime = distance_fn(&data[best_u], &data[v_prime]); - if d_prime < best_d { - best_v = v_prime; - best_d = d_prime; + for &v_prime in &undirected_graph[edge.v] { + if node_to_component[v_prime] == cj && v_prime != edge.u { + let d_prime = distance_fn(&data[best_u], &data[v_prime]); + if d_prime < best_d { + best_v = v_prime; + best_d = d_prime; + } } } From 12c98c9cf36d58a05bd7de17bcbd19479993206d Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Sat, 17 Jan 2026 12:13:00 -0600 Subject: [PATCH 10/12] Refactor AnnGraph to use single flat allocation - Store neighbors as flat Vec instead of Vec> + Vec> - Neighbor struct combines index and distance (reusing existing NeighborEntry) - Access via ann_graph.neighbors(i) returns &[Neighbor] slice - Eliminates n pointer indirections and reduces memory fragmentation --- crates/famst/src/lib.rs | 95 +++++++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 47 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 16f0f65..335ad06 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -70,25 +70,34 @@ impl UnionFind { } /// Approximate Nearest Neighbors graph representation -/// Contains neighbor indices and distances for each point +/// Stored as a flat n×k matrix of Neighbor entries pub struct AnnGraph { - /// neighbors[i] contains the indices of k nearest neighbors of point i - pub neighbors: Vec>, - /// distances[i] contains the distances to k nearest neighbors of point i - pub distances: Vec>, + /// Flat storage: data[i*k..(i+1)*k] contains k neighbors of point i + data: Vec, + /// Number of points + n: usize, + /// Number of neighbors per point + k: usize, } impl AnnGraph { - pub fn new(neighbors: Vec>, distances: Vec>) -> Self { - assert_eq!(neighbors.len(), distances.len()); - AnnGraph { - neighbors, - distances, - } + pub fn new(n: usize, k: usize, data: Vec) -> Self { + assert_eq!(data.len(), n * k); + AnnGraph { data, n, k } } pub fn n(&self) -> usize { - self.neighbors.len() + self.n + } + + pub fn k(&self) -> usize { + self.k + } + + /// Get the neighbors of point i + pub fn neighbors(&self, i: usize) -> &[Neighbor] { + let start = i * self.k; + &self.data[start..start + self.k] } } @@ -218,28 +227,29 @@ where } } -/// A neighbor entry in the k-NN heap (max-heap by distance for easy replacement of farthest) -#[derive(Clone, Copy)] -struct NeighborEntry { - index: usize, - distance: f64, +/// A neighbor entry: node index and distance +/// Used both in the k-NN heap and in the final AnnGraph +#[derive(Debug, Clone, Copy)] +pub struct Neighbor { + pub index: usize, + pub distance: f64, } -impl PartialEq for NeighborEntry { +impl PartialEq for Neighbor { fn eq(&self, other: &Self) -> bool { self.distance == other.distance } } -impl Eq for NeighborEntry {} +impl Eq for Neighbor {} -impl PartialOrd for NeighborEntry { +impl PartialOrd for Neighbor { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for NeighborEntry { +impl Ord for Neighbor { fn cmp(&self, other: &Self) -> std::cmp::Ordering { // Max-heap: larger distances have higher priority self.distance @@ -261,7 +271,7 @@ where let k = config.k.min(n - 1); if k == 0 || n <= 1 { - return AnnGraph::new(vec![vec![]; n], vec![vec![]; n]); + return AnnGraph::new(n, 0, vec![]); } // Helper: check if sorted vec contains value @@ -289,7 +299,7 @@ where // Initialize with random neighbors using max-heap for each point // neighbor_lists[i] is kept sorted by index for O(log k) membership tests - let mut heaps: Vec> = Vec::with_capacity(n); + let mut heaps: Vec> = Vec::with_capacity(n); let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; for i in 0..n { @@ -310,13 +320,13 @@ where let actual_t = if t >= i { t + 1 } else { t }; sorted_insert(&mut neighbor_lists[i], actual_t); let d = distance_fn(&data[i], &data[actual_t]); - heap.push(NeighborEntry { + heap.push(Neighbor { index: actual_t, distance: d, }); } else { let d = distance_fn(&data[i], &data[actual_j]); - heap.push(NeighborEntry { + heap.push(Neighbor { index: actual_j, distance: d, }); @@ -392,7 +402,7 @@ where let removed = heaps[i].pop().unwrap(); sorted_remove(&mut neighbor_lists[i], removed.index); - heaps[i].push(NeighborEntry { + heaps[i].push(Neighbor { index: c, distance: d, }); @@ -409,21 +419,16 @@ where } } - // Convert heaps to sorted neighbor lists - let mut neighbors = vec![Vec::with_capacity(k); n]; - let mut distances = vec![Vec::with_capacity(k); n]; + // Convert heaps to flat neighbor array sorted by distance + let mut result_data = Vec::with_capacity(n * k); - for (i, heap) in heaps.into_iter().enumerate() { - let mut entries: Vec = heap.into_vec(); + for heap in heaps { + let mut entries: Vec = heap.into_vec(); entries.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); - - for entry in entries { - neighbors[i].push(entry.index); - distances[i].push(entry.distance); - } + result_data.extend(entries); } - AnnGraph::new(neighbors, distances) + AnnGraph::new(n, k, result_data) } /// Find connected components in the ANN graph using DFS @@ -433,8 +438,9 @@ fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { // Build undirected graph from directed ANN graph using sorted vecs let mut graph: Vec> = vec![Vec::new(); n]; - for (i, neighbors) in ann_graph.neighbors.iter().enumerate() { - for &j in neighbors { + for i in 0..n { + for neighbor in ann_graph.neighbors(i) { + let j = neighbor.index; graph[i].push(j); graph[j].push(i); } @@ -588,14 +594,9 @@ fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec = Vec::new(); - for (i, (neighbors, distances)) in ann_graph - .neighbors - .iter() - .zip(ann_graph.distances.iter()) - .enumerate() - { - for (&j, &d) in neighbors.iter().zip(distances.iter()) { - edges.push(Edge::new(i, j, d)); + for i in 0..n { + for neighbor in ann_graph.neighbors(i) { + edges.push(Edge::new(i, neighbor.index, neighbor.distance)); } } From 2561f992e3e750a655833b80f1eb8351b857cf69 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Sat, 17 Jan 2026 12:20:22 -0600 Subject: [PATCH 11/12] Make internal types private Only expose the public API: famst, famst_with_rng, FamstConfig, FamstResult, Edge. UnionFind, AnnGraph, and Neighbor are now private implementation details. --- crates/famst/src/lib.rs | 79 ++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 335ad06..46e437d 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -30,27 +30,27 @@ impl Edge { } /// Union-Find (Disjoint Set Union) data structure for Kruskal's algorithm -pub struct UnionFind { +struct UnionFind { parent: Vec, rank: Vec, } impl UnionFind { - pub fn new(n: usize) -> Self { + fn new(n: usize) -> Self { UnionFind { parent: (0..n).collect(), rank: vec![0; n], } } - pub fn find(&mut self, x: usize) -> usize { + fn find(&mut self, x: usize) -> usize { if self.parent[x] != x { self.parent[x] = self.find(self.parent[x]); // Path compression } self.parent[x] } - pub fn union(&mut self, x: usize, y: usize) -> bool { + fn union(&mut self, x: usize, y: usize) -> bool { let px = self.find(x); let py = self.find(y); if px == py { @@ -69,9 +69,39 @@ impl UnionFind { } } +/// A neighbor entry: node index and distance +#[derive(Debug, Clone, Copy)] +struct Neighbor { + index: usize, + distance: f64, +} + +impl PartialEq for Neighbor { + fn eq(&self, other: &Self) -> bool { + self.distance == other.distance + } +} + +impl Eq for Neighbor {} + +impl PartialOrd for Neighbor { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Neighbor { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Max-heap: larger distances have higher priority + self.distance + .partial_cmp(&other.distance) + .unwrap_or(std::cmp::Ordering::Equal) + } +} + /// Approximate Nearest Neighbors graph representation /// Stored as a flat n×k matrix of Neighbor entries -pub struct AnnGraph { +struct AnnGraph { /// Flat storage: data[i*k..(i+1)*k] contains k neighbors of point i data: Vec, /// Number of points @@ -81,21 +111,21 @@ pub struct AnnGraph { } impl AnnGraph { - pub fn new(n: usize, k: usize, data: Vec) -> Self { + fn new(n: usize, k: usize, data: Vec) -> Self { assert_eq!(data.len(), n * k); AnnGraph { data, n, k } } - pub fn n(&self) -> usize { + fn n(&self) -> usize { self.n } - pub fn k(&self) -> usize { + fn k(&self) -> usize { self.k } /// Get the neighbors of point i - pub fn neighbors(&self, i: usize) -> &[Neighbor] { + fn neighbors(&self, i: usize) -> &[Neighbor] { let start = i * self.k; &self.data[start..start + self.k] } @@ -227,37 +257,6 @@ where } } -/// A neighbor entry: node index and distance -/// Used both in the k-NN heap and in the final AnnGraph -#[derive(Debug, Clone, Copy)] -pub struct Neighbor { - pub index: usize, - pub distance: f64, -} - -impl PartialEq for Neighbor { - fn eq(&self, other: &Self) -> bool { - self.distance == other.distance - } -} - -impl Eq for Neighbor {} - -impl PartialOrd for Neighbor { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Neighbor { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - // Max-heap: larger distances have higher priority - self.distance - .partial_cmp(&other.distance) - .unwrap_or(std::cmp::Ordering::Equal) - } -} - /// NN-Descent algorithm for approximate k-NN graph construction /// /// Based on: "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures" From 4ec00638c315fa16bd26e783c17268281efdbbd4 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Sat, 17 Jan 2026 13:19:35 -0600 Subject: [PATCH 12/12] Switch to 32-bit types for memory efficiency - Use NodeId (u32) type alias for node indices - Use f32 for distances instead of f64 - Add assertion that data.len() <= 2^32 with documented panic - This halves memory usage for internal data structures For n=1 billion, k=20, this saves ~120 GB of memory in the AnnGraph alone. --- crates/famst/src/lib.rs | 175 +++++++++++++++++++++------------------- 1 file changed, 94 insertions(+), 81 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 46e437d..a29a014 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -20,11 +20,11 @@ use std::collections::BinaryHeap; pub struct Edge { pub u: usize, pub v: usize, - pub distance: f64, + pub distance: f32, } impl Edge { - pub fn new(u: usize, v: usize, distance: f64) -> Self { + fn new(u: usize, v: usize, distance: f32) -> Self { Edge { u, v, distance } } } @@ -69,11 +69,14 @@ impl UnionFind { } } -/// A neighbor entry: node index and distance +/// Node index type (32-bit for memory efficiency, limits graphs to 2^32 nodes) +type NodeId = u32; + +/// A neighbor entry: node index and distance (32-bit for memory efficiency) #[derive(Debug, Clone, Copy)] struct Neighbor { - index: usize, - distance: f64, + index: NodeId, + distance: f32, } impl PartialEq for Neighbor { @@ -120,10 +123,6 @@ impl AnnGraph { self.n } - fn k(&self) -> usize { - self.k - } - /// Get the neighbors of point i fn neighbors(&self, i: usize) -> &[Neighbor] { let start = i * self.k; @@ -162,14 +161,14 @@ pub struct FamstResult { /// MST edges pub edges: Vec, /// Total weight of the MST - pub total_weight: f64, + pub total_weight: f32, } /// Main FAMST algorithm implementation /// /// Generic over: /// - `T`: The data type stored at each point -/// - `D`: Distance function `Fn(&T, &T) -> f64` +/// - `D`: Distance function `Fn(&T, &T) -> f32` /// /// # Arguments /// * `data` - Slice of data points @@ -178,14 +177,20 @@ pub struct FamstResult { /// /// # Returns /// The approximate MST as a list of edges +/// +/// # Panics +/// Panics if `data.len() >= 2^32` (more than ~4 billion points). pub fn famst(data: &[T], distance_fn: D, config: &FamstConfig) -> FamstResult where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, { famst_with_rng(data, distance_fn, config, &mut rand::thread_rng()) } /// FAMST with custom RNG. (We use a seeded RNG in tests for reproducibility.) +/// +/// # Panics +/// Panics if `data.len() >= 2^32` (more than ~4 billion points). pub fn famst_with_rng( data: &[T], distance_fn: D, @@ -193,10 +198,15 @@ pub fn famst_with_rng( rng: &mut R, ) -> FamstResult where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, R: Rng, { let n = data.len(); + assert!( + n <= NodeId::MAX as usize, + "famst: data length {n} exceeds maximum supported size of 2^32" + ); + if n <= 1 { return FamstResult { edges: vec![], @@ -263,7 +273,7 @@ where /// by Wei Dong, Charikar Moses, and Kai Li (2011) fn nn_descent(data: &[T], distance_fn: &D, config: &FamstConfig, rng: &mut R) -> AnnGraph where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, R: Rng, { let n = data.len(); @@ -274,12 +284,12 @@ where } // Helper: check if sorted vec contains value - fn sorted_contains(v: &[usize], x: usize) -> bool { + fn sorted_contains(v: &[NodeId], x: NodeId) -> bool { v.binary_search(&x).is_ok() } // Helper: insert into sorted vec, returns true if inserted (was not present) - fn sorted_insert(v: &mut Vec, x: usize) -> bool { + fn sorted_insert(v: &mut Vec, x: NodeId) -> bool { match v.binary_search(&x) { Ok(_) => false, Err(pos) => { @@ -290,7 +300,7 @@ where } // Helper: remove from sorted vec - fn sorted_remove(v: &mut Vec, x: usize) { + fn sorted_remove(v: &mut Vec, x: NodeId) { if let Ok(pos) = v.binary_search(&x) { v.remove(pos); } @@ -299,7 +309,7 @@ where // Initialize with random neighbors using max-heap for each point // neighbor_lists[i] is kept sorted by index for O(log k) membership tests let mut heaps: Vec> = Vec::with_capacity(n); - let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; + let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; for i in 0..n { let mut heap = BinaryHeap::with_capacity(k); @@ -312,19 +322,19 @@ where for t in range_start..effective_n { let j = rng.gen_range(0..=t); // Map j to actual index, skipping i - let actual_j = if j >= i { j + 1 } else { j }; + let actual_j = (if j >= i { j + 1 } else { j }) as NodeId; if !sorted_insert(&mut neighbor_lists[i], actual_j) { // j was already selected, so add t instead - let actual_t = if t >= i { t + 1 } else { t }; + let actual_t = (if t >= i { t + 1 } else { t }) as NodeId; sorted_insert(&mut neighbor_lists[i], actual_t); - let d = distance_fn(&data[i], &data[actual_t]); + let d = distance_fn(&data[i], &data[actual_t as usize]); heap.push(Neighbor { index: actual_t, distance: d, }); } else { - let d = distance_fn(&data[i], &data[actual_j]); + let d = distance_fn(&data[i], &data[actual_j as usize]); heap.push(Neighbor { index: actual_j, distance: d, @@ -336,11 +346,11 @@ where // Build reverse neighbor lists (who has me as a neighbor) // Returns sorted vecs for each point - let build_reverse = |neighbor_lists: &[Vec]| -> Vec> { - let mut reverse: Vec> = vec![Vec::new(); n]; + let build_reverse = |neighbor_lists: &[Vec]| -> Vec> { + let mut reverse: Vec> = vec![Vec::new(); n]; for (i, neighbors) in neighbor_lists.iter().enumerate() { for &j in neighbors { - reverse[j].push(i); + reverse[j as usize].push(i as NodeId); } } // Sort each reverse list (they're built in order of i, so already sorted) @@ -355,7 +365,7 @@ where // For each point, explore neighbors of neighbors for i in 0..n { // Collect candidates: neighbors and reverse neighbors - let mut candidates: Vec = Vec::new(); + let mut candidates: Vec = Vec::new(); // Sample from forward neighbors let mut sampled_forward = neighbor_lists[i].clone(); @@ -372,15 +382,16 @@ where sampled_reverse.truncate(sample_size); // Neighbors of neighbors + let i_id = i as NodeId; for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) { - for &nn in &neighbor_lists[neighbor] { - if nn != i && !sorted_contains(&neighbor_lists[i], nn) { + for &nn in &neighbor_lists[neighbor as usize] { + if nn != i_id && !sorted_contains(&neighbor_lists[i], nn) { candidates.push(nn); } } // Also check reverse neighbors of neighbors - for &rn in &reverse_neighbors[neighbor] { - if rn != i && !sorted_contains(&neighbor_lists[i], rn) { + for &rn in &reverse_neighbors[neighbor as usize] { + if rn != i_id && !sorted_contains(&neighbor_lists[i], rn) { candidates.push(rn); } } @@ -392,7 +403,7 @@ where // Try to improve neighbors for c in candidates { - let d = distance_fn(&data[i], &data[c]); + let d = distance_fn(&data[i], &data[c as usize]); // Check if this is better than the worst current neighbor if let Some(worst) = heaps[i].peek() { @@ -432,16 +443,16 @@ where /// Find connected components in the ANN graph using DFS /// Returns the undirected graph adjacency list (sorted vecs) and component assignments -fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { +fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { let n = ann_graph.n(); // Build undirected graph from directed ANN graph using sorted vecs - let mut graph: Vec> = vec![Vec::new(); n]; + let mut graph: Vec> = vec![Vec::new(); n]; for i in 0..n { for neighbor in ann_graph.neighbors(i) { let j = neighbor.index; graph[i].push(j); - graph[j].push(i); + graph[j as usize].push(i as NodeId); } } // Sort and deduplicate each adjacency list @@ -452,7 +463,7 @@ fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { // DFS to find components let mut visited = vec![false; n]; - let mut components: Vec> = Vec::new(); + let mut components: Vec> = Vec::new(); for start in 0..n { if visited[start] { @@ -460,17 +471,17 @@ fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { } let mut component = Vec::new(); - let mut stack = vec![start]; + let mut stack = vec![start as NodeId]; while let Some(u) = stack.pop() { - if visited[u] { + if visited[u as usize] { continue; } - visited[u] = true; + visited[u as usize] = true; component.push(u); - for &v in &graph[u] { - if !visited[v] { + for &v in &graph[u as usize] { + if !visited[v as usize] { stack.push(v); } } @@ -485,13 +496,13 @@ fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { /// Add random edges between components (Algorithm 3 in the paper) fn add_random_edges( data: &[T], - components: &[Vec], + components: &[Vec], lambda: usize, distance_fn: &D, rng: &mut R, ) -> (Vec, Vec<(usize, usize)>) where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, R: Rng, { let t = components.len(); @@ -506,8 +517,8 @@ where // Generate λ² candidate edges for _ in 0..lambda_sq { - let u = *components[i].choose(rng).unwrap(); - let v = *components[j].choose(rng).unwrap(); + let u = *components[i].choose(rng).unwrap() as usize; + let v = *components[j].choose(rng).unwrap() as usize; let d = distance_fn(&data[u], &data[v]); candidates.push(Edge::new(u, v, d)); } @@ -528,14 +539,14 @@ where /// Refine inter-component edges (Algorithm 4 in the paper) fn refine_edges( data: &[T], - undirected_graph: &[Vec], - components: &[Vec], + undirected_graph: &[Vec], + components: &[Vec], edges: &[Edge], edge_components: &[(usize, usize)], distance_fn: &D, ) -> (Vec, usize) where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, { let n = data.len(); @@ -543,7 +554,7 @@ where let mut node_to_component: Vec = vec![0; n]; for (comp_idx, component) in components.iter().enumerate() { for &node in component { - node_to_component[node] = comp_idx; + node_to_component[node as usize] = comp_idx; } } @@ -558,6 +569,7 @@ where // Get neighbors of u that are in component ci // (undirected_graph is sorted, so we can iterate directly) for &u_prime in &undirected_graph[edge.u] { + let u_prime = u_prime as usize; if node_to_component[u_prime] == ci && u_prime != edge.v { let d_prime = distance_fn(&data[u_prime], &data[best_v]); if d_prime < best_d { @@ -569,6 +581,7 @@ where // Get neighbors of v that are in component cj for &v_prime in &undirected_graph[edge.v] { + let v_prime = v_prime as usize; if node_to_component[v_prime] == cj && v_prime != edge.u { let d_prime = distance_fn(&data[best_u], &data[v_prime]); if d_prime < best_d { @@ -595,7 +608,7 @@ fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec f64 { + /// Manhattan distance for slices of f32 + fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 { a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum() } - /// Euclidean distance for slices of f64 - pub fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { + /// Euclidean distance for slices of f32 + fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 { a.iter() .zip(b.iter()) .map(|(x, y)| (x - y).powi(2)) - .sum::() + .sum::() .sqrt() } #[test] fn test_empty_input() { - let points: Vec> = vec![]; - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let points: Vec> = vec![]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let result = famst(&points, distance, &FamstConfig::default()); assert_eq!(result.edges.len(), 0); assert_eq!(result.total_weight, 0.0); @@ -657,8 +670,8 @@ mod tests { #[test] fn test_single_point() { - let points: Vec> = vec![vec![1.0, 2.0, 3.0]]; - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let points: Vec> = vec![vec![1.0, 2.0, 3.0]]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let result = famst(&points, distance, &FamstConfig::default()); assert_eq!(result.edges.len(), 0); assert_eq!(result.total_weight, 0.0); @@ -667,8 +680,8 @@ mod tests { #[test] fn test_k_greater_than_n() { // 3 points but k=20 (default), so k >= n - let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]]; - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig::default(); // k=20 > n=3 let result = famst(&points, distance, &config); assert_eq!(result.edges.len(), 2); // MST has n-1 edges @@ -687,13 +700,13 @@ mod tests { #[test] fn test_simple_mst() { // Simple 2D points forming a triangle - let points: Vec> = vec![ + let points: Vec> = vec![ vec![0.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.866], // Equilateral triangle ]; - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig { k: 2, ..Default::default() @@ -708,9 +721,9 @@ mod tests { #[test] fn test_line_points() { // Points on a line - let points: Vec> = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]]; + let points: Vec> = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]]; - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig { k: 2, ..Default::default() @@ -727,7 +740,7 @@ mod tests { #[test] fn test_disconnected_components() { // Two clusters far apart - let points: Vec> = vec![ + let points: Vec> = vec![ vec![0.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.5], @@ -737,7 +750,7 @@ mod tests { ]; // k=1 will likely create disconnected components - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig { k: 1, lambda: 3, @@ -754,9 +767,9 @@ mod tests { #[test] fn test_custom_distance() { // Test with Manhattan distance - let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]]; + let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]]; - let distance = |a: &Vec, b: &Vec| manhattan_distance(a, b); + let distance = |a: &Vec, b: &Vec| manhattan_distance(a, b); let config = FamstConfig { k: 2, ..Default::default() @@ -775,12 +788,12 @@ mod tests { // Test with a custom struct #[derive(Clone)] struct Point3D { - x: f64, - y: f64, - z: f64, + x: f32, + y: f32, + z: f32, } - fn point_distance(a: &Point3D, b: &Point3D) -> f64 { + fn point_distance(a: &Point3D, b: &Point3D) -> f32 { ((a.x - b.x).powi(2) + (a.y - b.y).powi(2) + (a.z - b.z).powi(2)).sqrt() } @@ -836,7 +849,7 @@ mod tests { ]; let points_per_cluster = 20; - let mut points: Vec> = Vec::new(); + let mut points: Vec> = Vec::new(); for center in &cluster_centers { for _ in 0..points_per_cluster { @@ -849,7 +862,7 @@ mod tests { } let n = points.len(); - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); // Use small k to create disconnected components // With k=3 and 20 points per cluster spread over 5 clusters, @@ -899,9 +912,9 @@ mod tests { } /// Compute exact MST using Kruskal's algorithm on complete graph - fn exact_mst_weight(data: &[T], distance_fn: D) -> f64 + fn exact_mst_weight(data: &[T], distance_fn: D) -> f32 where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, { let n = data.len(); if n <= 1 { @@ -909,7 +922,7 @@ mod tests { } // Build all edges - let mut edges: Vec<(usize, usize, f64)> = Vec::with_capacity(n * (n - 1) / 2); + let mut edges: Vec<(usize, usize, f32)> = Vec::with_capacity(n * (n - 1) / 2); for i in 0..n { for j in (i + 1)..n { let d = distance_fn(&data[i], &data[j]); @@ -950,12 +963,12 @@ mod tests { let mut rng = StdRng::seed_from_u64(12345); let dist = Uniform::new(0.0, 1000.0); - let points: Vec> = (0..N) + let points: Vec> = (0..N) .map(|_| (0..DIM).map(|_| dist.sample(&mut rng)).collect()) .collect(); println!("Running FAMST with NN-Descent..."); - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig { k: 20, lambda: 5, @@ -985,11 +998,11 @@ mod tests { let mut rng = StdRng::seed_from_u64(99999); let dist = Uniform::new(0.0, 100.0); - let points: Vec> = (0..N) + let points: Vec> = (0..N) .map(|_| (0..DIM).map(|_| dist.sample(&mut rng)).collect()) .collect(); - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); // Compute exact MST let exact_weight = exact_mst_weight(&points, distance);