Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ func (s *SelectWithUnionQuery) Pos() token.Position { return s.Position }
func (s *SelectWithUnionQuery) End() token.Position { return s.Position }
func (s *SelectWithUnionQuery) statementNode() {}

// SelectIntersectExceptQuery is the AST node for a query that combines
// SELECT statements with INTERSECT or EXCEPT.
type SelectIntersectExceptQuery struct {
	// Position is the node's source position; it is excluded from JSON output.
	Position token.Position `json:"-"`
	// Selects holds the statements that make up the query.
	Selects []Statement `json:"selects"`
}

// Pos reports the position stored on the node.
func (s *SelectIntersectExceptQuery) Pos() token.Position {
	return s.Position
}

// End reports the same stored position as Pos; no separate end position
// is tracked on this node.
func (s *SelectIntersectExceptQuery) End() token.Position {
	return s.Position
}

// statementNode is an empty marker method.
func (s *SelectIntersectExceptQuery) statementNode() {
}

// SelectQuery represents a SELECT statement.
type SelectQuery struct {
Position token.Position `json:"-"`
Expand Down Expand Up @@ -212,6 +222,8 @@ type InsertQuery struct {
Function *FunctionCall `json:"function,omitempty"` // For INSERT INTO FUNCTION syntax
Columns []*Identifier `json:"columns,omitempty"`
PartitionBy Expression `json:"partition_by,omitempty"` // For PARTITION BY clause
Infile string `json:"infile,omitempty"` // For FROM INFILE clause
Compression string `json:"compression,omitempty"` // For COMPRESSION clause
Select Statement `json:"select,omitempty"`
Format *Identifier `json:"format,omitempty"`
HasSettings bool `json:"has_settings,omitempty"` // For SETTINGS clause
Expand Down
2 changes: 2 additions & 0 deletions internal/explain/explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ func Node(sb *strings.Builder, node interface{}, depth int) {
// Select statements
case *ast.SelectWithUnionQuery:
explainSelectWithUnionQuery(sb, n, indent, depth)
case *ast.SelectIntersectExceptQuery:
explainSelectIntersectExceptQuery(sb, n, indent, depth)
case *ast.SelectQuery:
explainSelectQuery(sb, n, indent, depth)

Expand Down
137 changes: 127 additions & 10 deletions internal/explain/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,58 @@ func explainLiteral(sb *strings.Builder, n *ast.Literal, indent string, depth in
fmt.Fprintf(sb, "%s ExpressionList\n", indent)
return
}
hasComplexExpr := false
// Check if we should render as Function array
// This happens when:
// 1. Contains non-literal, non-negation expressions OR
// 2. Contains tuples OR
// 3. Contains nested arrays that all have exactly 1 element (homogeneous single-element arrays) OR
// 4. Contains nested arrays with non-literal expressions OR
// 5. Contains nested arrays that are empty or contain tuples/non-literals
shouldUseFunctionArray := false
allAreSingleElementArrays := true
hasNestedArrays := false
nestedArraysNeedFunctionFormat := false

for _, e := range exprs {
if !isSimpleLiteralOrNegation(e) {
hasComplexExpr = true
break
if lit, ok := e.(*ast.Literal); ok {
if lit.Type == ast.LiteralArray {
hasNestedArrays = true
// Check if this inner array has exactly 1 element
if innerExprs, ok := lit.Value.([]ast.Expression); ok {
if len(innerExprs) != 1 {
allAreSingleElementArrays = false
}
// Check if inner array needs Function array format:
// - Contains non-literal expressions OR
// - Contains tuples OR
// - Is empty OR
// - Contains empty arrays
if containsNonLiteralExpressions(innerExprs) ||
len(innerExprs) == 0 ||
containsTuples(innerExprs) ||
containsEmptyArrays(innerExprs) {
nestedArraysNeedFunctionFormat = true
}
} else {
allAreSingleElementArrays = false
}
} else if lit.Type == ast.LiteralTuple {
// Tuples are complex
shouldUseFunctionArray = true
}
} else if !isSimpleLiteralOrNegation(e) {
shouldUseFunctionArray = true
}
}
if hasComplexExpr {

// Use Function array when:
// - nested arrays that are ALL single-element
// - nested arrays that need Function format (contain non-literals, tuples, or empty arrays)
if hasNestedArrays && (allAreSingleElementArrays || nestedArraysNeedFunctionFormat) {
shouldUseFunctionArray = true
}

if shouldUseFunctionArray {
// Render as Function array instead of Literal
fmt.Fprintf(sb, "%sFunction array (children %d)\n", indent, 1)
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(exprs))
Expand Down Expand Up @@ -124,6 +168,58 @@ func isSimpleLiteralOrNegation(e ast.Expression) bool {
return false
}

// containsOnlyArraysOrTuples reports whether every element of exprs is a
// literal whose type is either an array or a tuple.
// An empty slice trivially satisfies the condition and yields true.
func containsOnlyArraysOrTuples(exprs []ast.Expression) bool {
	for _, expr := range exprs {
		lit, isLiteral := expr.(*ast.Literal)
		if !isLiteral {
			// Any non-literal expression disqualifies the slice.
			return false
		}
		switch lit.Type {
		case ast.LiteralArray, ast.LiteralTuple:
			continue
		default:
			return false
		}
	}
	return true
}

// containsNonLiteralExpressions reports whether at least one element of
// exprs is something other than an *ast.Literal.
// It returns false for an empty slice.
func containsNonLiteralExpressions(exprs []ast.Expression) bool {
	for i := range exprs {
		if _, isLiteral := exprs[i].(*ast.Literal); isLiteral {
			continue
		}
		return true
	}
	return false
}

// containsTuples reports whether any element of exprs is a literal of
// tuple type. Non-literal elements are ignored.
func containsTuples(exprs []ast.Expression) bool {
	for _, expr := range exprs {
		lit, isLiteral := expr.(*ast.Literal)
		if !isLiteral {
			continue
		}
		if lit.Type == ast.LiteralTuple {
			return true
		}
	}
	return false
}

// containsEmptyArrays reports whether any element of exprs is an array
// literal whose value is a zero-length []ast.Expression.
// Elements that are not array literals, or whose value has another
// dynamic type, are skipped.
func containsEmptyArrays(exprs []ast.Expression) bool {
	for _, expr := range exprs {
		lit, isLiteral := expr.(*ast.Literal)
		if !isLiteral || lit.Type != ast.LiteralArray {
			continue
		}
		elems, isList := lit.Value.([]ast.Expression)
		if isList && len(elems) == 0 {
			return true
		}
	}
	return false
}

func explainBinaryExpr(sb *strings.Builder, n *ast.BinaryExpr, indent string, depth int) {
// Convert operator to function name
fnName := OperatorToFunction(n.Op)
Expand Down Expand Up @@ -303,11 +399,20 @@ func explainAsterisk(sb *strings.Builder, n *ast.Asterisk, indent string) {

func explainWithElement(sb *strings.Builder, n *ast.WithElement, indent string, depth int) {
// For WITH elements, we need to show the underlying expression with the name as alias
// When name is empty, don't show the alias part
switch e := n.Query.(type) {
case *ast.Literal:
fmt.Fprintf(sb, "%sLiteral %s (alias %s)\n", indent, FormatLiteral(e), n.Name)
if n.Name != "" {
fmt.Fprintf(sb, "%sLiteral %s (alias %s)\n", indent, FormatLiteral(e), n.Name)
} else {
fmt.Fprintf(sb, "%sLiteral %s\n", indent, FormatLiteral(e))
}
case *ast.Identifier:
fmt.Fprintf(sb, "%sIdentifier %s (alias %s)\n", indent, e.Name(), n.Name)
if n.Name != "" {
fmt.Fprintf(sb, "%sIdentifier %s (alias %s)\n", indent, e.Name(), n.Name)
} else {
fmt.Fprintf(sb, "%sIdentifier %s\n", indent, e.Name())
}
case *ast.FunctionCall:
explainFunctionCallWithAlias(sb, e, n.Name, indent, depth)
case *ast.BinaryExpr:
Expand All @@ -316,19 +421,31 @@ func explainWithElement(sb *strings.Builder, n *ast.WithElement, indent string,
// For || (concat) operator, flatten chained concatenations
if e.Op == "||" {
operands := collectConcatOperands(e)
fmt.Fprintf(sb, "%sFunction %s (alias %s) (children %d)\n", indent, fnName, n.Name, 1)
if n.Name != "" {
fmt.Fprintf(sb, "%sFunction %s (alias %s) (children %d)\n", indent, fnName, n.Name, 1)
} else {
fmt.Fprintf(sb, "%sFunction %s (children %d)\n", indent, fnName, 1)
}
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(operands))
for _, op := range operands {
Node(sb, op, depth+2)
}
} else {
fmt.Fprintf(sb, "%sFunction %s (alias %s) (children %d)\n", indent, fnName, n.Name, 1)
if n.Name != "" {
fmt.Fprintf(sb, "%sFunction %s (alias %s) (children %d)\n", indent, fnName, n.Name, 1)
} else {
fmt.Fprintf(sb, "%sFunction %s (children %d)\n", indent, fnName, 1)
}
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, 2)
Node(sb, e.Left, depth+2)
Node(sb, e.Right, depth+2)
}
case *ast.Subquery:
fmt.Fprintf(sb, "%sSubquery (alias %s) (children %d)\n", indent, n.Name, 1)
if n.Name != "" {
fmt.Fprintf(sb, "%sSubquery (alias %s) (children %d)\n", indent, n.Name, 1)
} else {
fmt.Fprintf(sb, "%sSubquery (children %d)\n", indent, 1)
}
Node(sb, e.Query, depth+1)
default:
// For other types, just output the expression (alias may be lost)
Expand Down
1 change: 0 additions & 1 deletion internal/explain/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ func NormalizeFunctionName(name string) string {
"lcase": "lower",
"ucase": "upper",
"mid": "substring",
"substr": "substring",
"ceiling": "ceil",
"ln": "log",
"log10": "log10",
Expand Down
37 changes: 37 additions & 0 deletions internal/explain/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,39 @@ func explainInExpr(sb *strings.Builder, n *ast.InExpr, indent string, depth int)
fnName = "global" + strings.Title(fnName)
}
fmt.Fprintf(sb, "%sFunction %s (children %d)\n", indent, fnName, 1)

// Determine if the IN list should be combined into a single tuple literal
// This happens when we have multiple literals of the same type:
// - All numeric literals (integers/floats)
// - All tuple literals
canBeTupleLiteral := false
if n.Query == nil && len(n.List) > 1 {
allNumeric := true
allTuples := true
for _, item := range n.List {
if lit, ok := item.(*ast.Literal); ok {
if lit.Type != ast.LiteralInteger && lit.Type != ast.LiteralFloat {
allNumeric = false
}
if lit.Type != ast.LiteralTuple {
allTuples = false
}
} else {
allNumeric = false
allTuples = false
break
}
}
canBeTupleLiteral = allNumeric || allTuples
}

// Count arguments: expr + list items or subquery
argCount := 1
if n.Query != nil {
argCount++
} else if canBeTupleLiteral {
// Multiple literals will be combined into a single tuple
argCount++
} else {
// Check if we have a single tuple literal that should be wrapped in Function tuple
if len(n.List) == 1 {
Expand All @@ -198,10 +227,18 @@ func explainInExpr(sb *strings.Builder, n *ast.InExpr, indent string, depth int)
}
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, argCount)
Node(sb, n.Expr, depth+2)

if n.Query != nil {
// Subqueries in IN should be wrapped in Subquery node
fmt.Fprintf(sb, "%s Subquery (children %d)\n", indent, 1)
Node(sb, n.Query, depth+3)
} else if canBeTupleLiteral {
// Combine multiple literals into a single Tuple literal
tupleLit := &ast.Literal{
Type: ast.LiteralTuple,
Value: n.List,
}
fmt.Fprintf(sb, "%s Literal %s\n", indent, FormatLiteral(tupleLit))
} else if len(n.List) == 1 {
// Single element in the list
// If it's a tuple literal, wrap it in Function tuple
Expand Down
7 changes: 7 additions & 0 deletions internal/explain/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ import (
"github.com/sqlc-dev/doubleclick/ast"
)

// explainSelectIntersectExceptQuery writes a header line for the node,
// including the number of contained statements, and then renders each
// statement through Node one depth level deeper.
func explainSelectIntersectExceptQuery(sb *strings.Builder, n *ast.SelectIntersectExceptQuery, indent string, depth int) {
	selects := n.Selects
	fmt.Fprintf(sb, "%sSelectIntersectExceptQuery (children %d)\n", indent, len(selects))

	childDepth := depth + 1
	for i := range selects {
		Node(sb, selects[i], childDepth)
	}
}

func explainSelectWithUnionQuery(sb *strings.Builder, n *ast.SelectWithUnionQuery, indent string, depth int) {
children := countSelectUnionChildren(n)
fmt.Fprintf(sb, "%sSelectWithUnionQuery (children %d)\n", indent, children)
Expand Down
15 changes: 15 additions & 0 deletions internal/explain/statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import (
func explainInsertQuery(sb *strings.Builder, n *ast.InsertQuery, indent string, depth int) {
// Count children
children := 0
if n.Infile != "" {
children++
}
if n.Compression != "" {
children++
}
if n.Function != nil {
children++
} else if n.Table != "" {
Expand All @@ -24,6 +30,15 @@ func explainInsertQuery(sb *strings.Builder, n *ast.InsertQuery, indent string,
// Note: InsertQuery uses 3 spaces after name in ClickHouse explain
fmt.Fprintf(sb, "%sInsertQuery (children %d)\n", indent, children)

// FROM INFILE path comes first
if n.Infile != "" {
fmt.Fprintf(sb, "%s Literal \\'%s\\'\n", indent, n.Infile)
}
// COMPRESSION value comes next
if n.Compression != "" {
fmt.Fprintf(sb, "%s Literal \\'%s\\'\n", indent, n.Compression)
}

if n.Function != nil {
Node(sb, n.Function, depth+1)
} else if n.Table != "" {
Expand Down
18 changes: 14 additions & 4 deletions lexer/lexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,16 +324,26 @@ func (l *Lexer) readBlockComment() Item {
sb.WriteRune(l.ch)
l.readChar()

for !l.eof {
// Track nesting level for nested comments (ClickHouse supports nested /* */ comments)
nesting := 1

for !l.eof && nesting > 0 {
if l.ch == '*' && l.peekChar() == '/' {
sb.WriteRune(l.ch)
l.readChar()
sb.WriteRune(l.ch)
l.readChar()
break
nesting--
} else if l.ch == '/' && l.peekChar() == '*' {
sb.WriteRune(l.ch)
l.readChar()
sb.WriteRune(l.ch)
l.readChar()
nesting++
} else {
sb.WriteRune(l.ch)
l.readChar()
}
sb.WriteRune(l.ch)
l.readChar()
}
return Item{Token: token.COMMENT, Value: sb.String(), Pos: pos}
}
Expand Down
Loading