From 0c7fbafda1eea7a37884dcb96eea9f662e1b439c Mon Sep 17 00:00:00 2001 From: Ville Vesilehto Date: Tue, 6 Jan 2026 20:25:51 +0200 Subject: [PATCH] perf(optimizer): add count threshold comparisons Optimize the following patterns: - count(arr, pred) > N - count(arr, pred) >= N - count(arr, pred) < N - count(arr, pred) <= N Add a threshold check inside the count loop. When the count reaches the threshold, the loop exits early instead of scanning the entire array. This is implemented via a new Threshold field on BuiltinNode that the optimizer sets when detecting these patterns. The compiler then emits bytecode that checks the count against the threshold after each increment and jumps out of the loop when reached. Signed-off-by: Ville Vesilehto --- ast/node.go | 1 + compiler/compiler.go | 17 ++ optimizer/count_threshold.go | 54 ++++++ optimizer/count_threshold_test.go | 278 ++++++++++++++++++++++++++++++ optimizer/optimizer.go | 1 + 5 files changed, 351 insertions(+) create mode 100644 optimizer/count_threshold.go create mode 100644 optimizer/count_threshold_test.go diff --git a/ast/node.go b/ast/node.go index 198efa59..fbb9ae82 100644 --- a/ast/node.go +++ b/ast/node.go @@ -187,6 +187,7 @@ type BuiltinNode struct { Arguments []Node // Arguments of the builtin function. Throws bool // If true then accessing a field or array index can throw an error. Used by optimizer. Map Node // Used by optimizer to fold filter() and map() builtins. + Threshold *int // Used by optimizer for count() early termination. } // PredicateNode represents a predicate. 
diff --git a/compiler/compiler.go b/compiler/compiler.go index 951385cd..ed8942c9 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -937,6 +937,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { c.compile(node.Arguments[0]) c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) + var loopBreak int c.emitLoop(func() { if len(node.Arguments) == 2 { c.compile(node.Arguments[1]) @@ -945,9 +946,25 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { } c.emitCond(func() { c.emit(OpIncrementCount) + // Early termination if threshold is set + if node.Threshold != nil { + c.emit(OpGetCount) + c.emit(OpInt, *node.Threshold) + c.emit(OpMoreOrEqual) + loopBreak = c.emit(OpJumpIfTrue, placeholder) + c.emit(OpPop) + } }) }) c.emit(OpGetCount) + if node.Threshold != nil { + end := c.emit(OpJump, placeholder) + c.patchJump(loopBreak) + // Early exit path: pop the bool comparison result, push count + c.emit(OpPop) + c.emit(OpGetCount) + c.patchJump(end) + } c.emit(OpEnd) return diff --git a/optimizer/count_threshold.go b/optimizer/count_threshold.go new file mode 100644 index 00000000..d045760b --- /dev/null +++ b/optimizer/count_threshold.go @@ -0,0 +1,54 @@ +package optimizer + +import ( + . "github.com/expr-lang/expr/ast" +) + +// countThreshold optimizes count comparisons by setting a threshold for early termination. +// The threshold allows the count loop to exit early once enough matches are found. 
+// Patterns: +// - count(arr, pred) > N → threshold = N + 1 (exit proves > N is true) +// - count(arr, pred) >= N → threshold = N (exit proves >= N is true) +// - count(arr, pred) < N → threshold = N (exit proves < N is false) +// - count(arr, pred) <= N → threshold = N + 1 (exit proves <= N is false) +type countThreshold struct{} + +func (*countThreshold) Visit(node *Node) { + binary, ok := (*node).(*BinaryNode) + if !ok { + return + } + + count, ok := binary.Left.(*BuiltinNode) + if !ok || count.Name != "count" || len(count.Arguments) != 2 { + return + } + + integer, ok := binary.Right.(*IntegerNode) + if !ok || integer.Value < 0 { + return + } + + var threshold int + switch binary.Operator { + case ">": + threshold = integer.Value + 1 + case ">=": + threshold = integer.Value + case "<": + threshold = integer.Value + case "<=": + threshold = integer.Value + 1 + default: + return + } + + // Skip thresholds of 0 or 1: `> 0` and `>= 1` are handled by the countAny optimizer; other such comparisons are left unoptimized + if threshold <= 1 { + return + } + + // Set threshold on the count node for early termination + // The original comparison remains unchanged + count.Threshold = &threshold +} diff --git a/optimizer/count_threshold_test.go b/optimizer/count_threshold_test.go new file mode 100644 index 00000000..3bac6fc3 --- /dev/null +++ b/optimizer/count_threshold_test.go @@ -0,0 +1,278 @@ +package optimizer_test + +import ( + "testing" + + "github.com/expr-lang/expr" + .
"github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/internal/testify/assert" + "github.com/expr-lang/expr/internal/testify/require" + "github.com/expr-lang/expr/optimizer" + "github.com/expr-lang/expr/parser" + "github.com/expr-lang/expr/vm" +) + +func TestOptimize_count_threshold_gt(t *testing.T) { + tree, err := parser.Parse(`count(items, .active) > 100`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Operator should remain >, but count should have threshold set + binary, ok := tree.Node.(*BinaryNode) + require.True(t, ok, "expected BinaryNode, got %T", tree.Node) + assert.Equal(t, ">", binary.Operator) + + count, ok := binary.Left.(*BuiltinNode) + require.True(t, ok, "expected BuiltinNode, got %T", binary.Left) + assert.Equal(t, "count", count.Name) + require.NotNil(t, count.Threshold) + assert.Equal(t, 101, *count.Threshold) // threshold = N + 1 for > operator +} + +func TestOptimize_count_threshold_gte(t *testing.T) { + tree, err := parser.Parse(`count(items, .active) >= 50`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Operator should remain >=, but count should have threshold set + binary, ok := tree.Node.(*BinaryNode) + require.True(t, ok, "expected BinaryNode, got %T", tree.Node) + assert.Equal(t, ">=", binary.Operator) + + count, ok := binary.Left.(*BuiltinNode) + require.True(t, ok, "expected BuiltinNode, got %T", binary.Left) + assert.Equal(t, "count", count.Name) + require.NotNil(t, count.Threshold) + assert.Equal(t, 50, *count.Threshold) // threshold = N for >= operator +} + +func TestOptimize_count_threshold_lt(t *testing.T) { + tree, err := parser.Parse(`count(items, .active) < 100`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Operator should remain <, but count should have threshold set + binary, ok := tree.Node.(*BinaryNode) + require.True(t, ok, "expected 
BinaryNode, got %T", tree.Node) + assert.Equal(t, "<", binary.Operator) + + count, ok := binary.Left.(*BuiltinNode) + require.True(t, ok, "expected BuiltinNode, got %T", binary.Left) + assert.Equal(t, "count", count.Name) + require.NotNil(t, count.Threshold) + assert.Equal(t, 100, *count.Threshold) // threshold = N for < operator +} + +func TestOptimize_count_threshold_lte(t *testing.T) { + tree, err := parser.Parse(`count(items, .active) <= 50`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Operator should remain <=, but count should have threshold set + binary, ok := tree.Node.(*BinaryNode) + require.True(t, ok, "expected BinaryNode, got %T", tree.Node) + assert.Equal(t, "<=", binary.Operator) + + count, ok := binary.Left.(*BuiltinNode) + require.True(t, ok, "expected BuiltinNode, got %T", binary.Left) + assert.Equal(t, "count", count.Name) + require.NotNil(t, count.Threshold) + assert.Equal(t, 51, *count.Threshold) // threshold = N + 1 for <= operator +} + +func TestOptimize_count_threshold_correctness(t *testing.T) { + tests := []struct { + expr string + want bool + }{ + // count > N (threshold = N + 1) + {`count(1..1000, # <= 100) > 50`, true}, // 100 matches > 50 + {`count(1..1000, # <= 100) > 100`, false}, // 100 matches not > 100 + {`count(1..1000, # <= 100) > 99`, true}, // 100 matches > 99 + {`count(1..100, # > 0) > 50`, true}, // 100 matches > 50 + {`count(1..100, # > 0) > 100`, false}, // 100 matches not > 100 + + // count >= N (threshold = N) + {`count(1..1000, # <= 100) >= 100`, true}, // 100 matches >= 100 + {`count(1..1000, # <= 100) >= 101`, false}, // 100 matches not >= 101 + {`count(1..100, # > 0) >= 50`, true}, // 100 matches >= 50 + {`count(1..100, # > 0) >= 100`, true}, // 100 matches >= 100 + + // count < N (threshold = N) + {`count(1..1000, # <= 100) < 101`, true}, // 100 matches < 101 + {`count(1..1000, # <= 100) < 100`, false}, // 100 matches not < 100 + {`count(1..1000, # <= 100) 
< 50`, false}, // 100 matches not < 50 + {`count(1..100, # > 0) < 101`, true}, // 100 matches < 101 + {`count(1..100, # > 0) < 100`, false}, // 100 matches not < 100 + + // count <= N (threshold = N + 1) + {`count(1..1000, # <= 100) <= 100`, true}, // 100 matches <= 100 + {`count(1..1000, # <= 100) <= 99`, false}, // 100 matches not <= 99 + {`count(1..1000, # <= 100) <= 50`, false}, // 100 matches not <= 50 + {`count(1..100, # > 0) <= 100`, true}, // 100 matches <= 100 + {`count(1..100, # > 0) <= 99`, false}, // 100 matches not <= 99 + } + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + program, err := expr.Compile(tt.expr) + require.NoError(t, err) + + output, err := expr.Run(program, nil) + require.NoError(t, err) + assert.Equal(t, tt.want, output) + }) + } +} + +func TestOptimize_count_threshold_no_optimization(t *testing.T) { + // These should NOT get a threshold (handled by count_any or not optimizable) + tests := []struct { + code string + threshold bool + }{ + {`count(items, .active) > 0`, false}, // handled by count_any + {`count(items, .active) >= 1`, false}, // handled by count_any + {`count(items, .active) < 1`, false}, // threshold = 1, skipped + {`count(items, .active) <= 0`, false}, // threshold = 1, skipped + {`count(items, .active) == 10`, false}, // not supported + } + + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + tree, err := parser.Parse(tt.code) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Check if count has threshold set + var count *BuiltinNode + if binary, ok := tree.Node.(*BinaryNode); ok { + count, _ = binary.Left.(*BuiltinNode) + } else if builtin, ok := tree.Node.(*BuiltinNode); ok { + count = builtin + } + + if count != nil && count.Name == "count" { + if tt.threshold { + assert.NotNil(t, count.Threshold, "expected threshold to be set") + } else { + assert.Nil(t, count.Threshold, "expected threshold to be nil") + } + } + }) + } +} + 
+// Benchmark: count > 100 with early match (the loop exits at the 101st match) +func BenchmarkCountThresholdEarlyMatch(b *testing.B) { + // Array of 10000 elements, all match predicate, threshold is 101 + // Should exit after ~101 iterations + program, _ := expr.Compile(`count(1..10000, # > 0) > 100`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count >= 50 with early match +func BenchmarkCountThresholdGteEarlyMatch(b *testing.B) { + // All elements match, threshold is 50 + // Should exit after ~50 iterations + program, _ := expr.Compile(`count(1..10000, # > 0) >= 50`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count > 100 with no early exit (not enough matches) +func BenchmarkCountThresholdNoEarlyExit(b *testing.B) { + // Only 100 elements match (# <= 100), threshold is 101 + // Must scan entire array + program, _ := expr.Compile(`count(1..10000, # <= 100) > 100`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: Large threshold with early match +func BenchmarkCountThresholdLargeEarlyMatch(b *testing.B) { + // All 10000 match, threshold is 1000 + // Should exit after ~1000 iterations + program, _ := expr.Compile(`count(1..10000, # > 0) > 999`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count < N with early exit (result is false) +func BenchmarkCountThresholdLtEarlyExit(b *testing.B) { + // All 10000 match, threshold is 100 + // Should exit after ~100 iterations with result = false + program, _ := expr.Compile(`count(1..10000, # > 0) < 100`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count <= N with early exit (result is false) +func BenchmarkCountThresholdLteEarlyExit(b
*testing.B) { + // All 10000 match, threshold is 51 + // Should exit after ~51 iterations with result = false + program, _ := expr.Compile(`count(1..10000, # > 0) <= 50`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count < N without early exit (result is true) +func BenchmarkCountThresholdLtNoEarlyExit(b *testing.B) { + // Only 100 elements match (# <= 100), threshold is 200 + // Must scan entire array, result = true + program, _ := expr.Compile(`count(1..10000, # <= 100) < 200`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count <= N without early exit (result is true) +func BenchmarkCountThresholdLteNoEarlyExit(b *testing.B) { + // Only 100 elements match (# <= 100), threshold is 101 + // Must scan entire array, result = true + program, _ := expr.Compile(`count(1..10000, # <= 100) <= 100`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index fedf0208..9e4c75d3 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -44,6 +44,7 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &sumArray{}) Walk(node, &sumMap{}) Walk(node, &countAny{}) + Walk(node, &countThreshold{}) return nil }