diff --git a/parser.go b/parser.go index b8dcfaee..2430e020 100644 --- a/parser.go +++ b/parser.go @@ -2807,14 +2807,42 @@ func (p *Parser) parseFrameSpec() (_ *FrameSpec, err error) { return &spec, nil } -func (p *Parser) parseParenExpr() (_ *ParenExpr, err error) { - var expr ParenExpr - expr.Lparen, _, _ = p.scan() - if expr.X, err = p.ParseExpr(); err != nil { - return &expr, err +func (p *Parser) parseParenExpr() (Expr, error) { + lparen, _, _ := p.scan() + + // Parse the first expression + x, err := p.ParseExpr() + if err != nil { + return nil, err } - expr.Rparen, _, _ = p.scan() - return &expr, nil + + // If there's no comma after the first expression, treat it as a normal parenthesized expression + if p.peek() != COMMA { + rparen, _, _ := p.scan() + return &ParenExpr{Lparen: lparen, X: x, Rparen: rparen}, nil + } + + // If there's a comma, we're dealing with an expression list + var list ExprList + list.Lparen = lparen + list.Exprs = append(list.Exprs, x) + + for p.peek() == COMMA { + p.scan() // consume the comma + + expr, err := p.ParseExpr() + if err != nil { + return &list, err + } + list.Exprs = append(list.Exprs, expr) + } + + if p.peek() != RP { + return &list, p.errorExpected(p.pos, p.tok, "right paren") + } + list.Rparen, _, _ = p.scan() + + return &list, nil } func (p *Parser) parseCastExpr() (_ *CastExpr, err error) { diff --git a/simple_test.go b/simple_test.go new file mode 100644 index 00000000..74ed2acd --- /dev/null +++ b/simple_test.go @@ -0,0 +1,18 @@ +package sql + +import ( + "strings" + "testing" +) + +func TestSimpleParenList(t *testing.T) { + s := `UPDATE table1 SET col1 = 'value' WHERE (col1, col2) = ('a', 'b')` + stmt, err := NewParser(strings.NewReader(s)).ParseStatement() + if err != nil { + t.Fatalf("failed: %v", err) + } + _, ok := stmt.(*UpdateStatement) + if !ok { + t.Fatalf("failed: expected UpdateStatement") + } +} diff --git a/update_test.go b/update_test.go new file mode 100644 index 00000000..bcb4d8ef --- /dev/null +++ b/update_test.go @@ -0,0 +1,37 @@ +package sql + +import ( + "strings" + "testing" +) + +func TestUpdate(t *testing.T) { + s := `UPDATE asynq_tasks +SET state='active', + pending_since=NULL, + affinity_timeout=server_affinity, + deadline=iif(task_deadline=0, task_timeout+1687276020, task_deadline) +WHERE asynq_tasks.state='pending' + AND (task_uuid, + ndx, + pndx, + task_msg, + task_timeout, + task_deadline)= + (SELECT task_uuid, + ndx, + pndx, + task_msg, + task_timeout, + task_deadline + FROM asynq_tasks) +` + stmt, err := NewParser(strings.NewReader(s)).ParseStatement() + if err != nil { + t.Fatalf("failed: %v", err) + } + _, ok := stmt.(*UpdateStatement) + if !ok { + t.Fatalf("failed: expected UpdateStatement") + } +}