diff --git a/src/parser/bison_parser.y b/src/parser/bison_parser.y index 7d280a9..8163db2 100755 --- a/src/parser/bison_parser.y +++ b/src/parser/bison_parser.y @@ -195,7 +195,7 @@ int yyerror(YYLTYPE* llocp, SQLParserResult* result, yyscan_t scanner, const cha %type expr operand scalar_expr unary_expr binary_expr logic_expr exists_expr %type function_expr between_expr expr_alias param_expr %type column_name literal int_literal num_literal string_literal -%type comp_expr opt_where join_condition opt_having case_expr in_expr hint +%type comp_expr opt_where join_condition opt_having case_expr case_list in_expr hint %type array_expr array_index null_literal %type opt_limit opt_top %type order_desc @@ -781,11 +781,18 @@ in_expr: | operand NOT IN '(' select_no_paren ')' { $$ = Expr::makeOpUnary(kOpNot, Expr::makeInOperator($1, $5)); } ; -// TODO: allow no else specified +// CASE grammar based on: flex & bison by John Levine +// https://www.safaribooksonline.com/library/view/flex-bison/9780596805418/ch04.html#id352665 case_expr: - CASE WHEN expr THEN operand END { $$ = Expr::makeCase($3, $5); } - | - CASE WHEN expr THEN operand ELSE operand END { $$ = Expr::makeCase($3, $5, $7); } + CASE expr case_list END { $$ = Expr::makeCaseExpr($2, $3); } + | CASE expr case_list ELSE expr END { $$ = Expr::makeCaseExpr($2, $3, $5); } + | CASE case_list END { $$ = Expr::makeCase($2); } + | CASE case_list ELSE expr END { $$ = Expr::makeCase($2, $4); } + ; + +case_list: + WHEN expr THEN expr { $$ = Expr::makeCaseCondition($2, $4); } + | case_list WHEN expr THEN expr { $$ = Expr::joinCaseCondition($1, Expr::makeCaseCondition($3, $5)); } ; exists_expr: diff --git a/src/sql/Expr.cpp b/src/sql/Expr.cpp index 51632e3..9cc05a4 100644 --- a/src/sql/Expr.cpp +++ b/src/sql/Expr.cpp @@ -68,22 +68,53 @@ namespace hsql { return e; } - Expr* Expr::makeCase(Expr* expr, Expr* then) { - Expr* e = new Expr(kExprOperator); + Expr* Expr::makeCaseCondition(Expr* expr, Expr* then) { + Expr* e = new Expr(kExprWhenCondition); e->expr = expr; - e->opType = kOpCase; - e->exprList = new std::vector(); - e->exprList->push_back(then); + e->expr2 = then; return e; } - Expr* Expr::makeCase(Expr* expr, Expr* then, Expr* other) { + Expr* Expr::joinCaseCondition(Expr* expr1, Expr* expr2) { + Expr* e = new Expr(kExprOperator); + e->opType = kOpPlus; + if (expr1->exprList != nullptr) { + e->exprList = expr1->exprList; + } else { + e->exprList = new std::vector(); + e->exprList->push_back(expr1); + } + e->exprList->push_back(expr2); + return e; + } + + Expr* Expr::makeCase(Expr* when) { Expr* e = new Expr(kExprOperator); - e->expr = expr; e->opType = kOpCase; - e->exprList = new std::vector(); - e->exprList->push_back(then); - e->exprList->push_back(other); + if (when->exprList != nullptr) { + e->exprList = when->exprList; + } else { + e->exprList = new std::vector(); + e->exprList->push_back(when); + } + return e; + } + + Expr* Expr::makeCase(Expr* when, Expr* other) { + Expr* e = Expr::makeCase(when); + e->expr2 = other; + return e; + } + + Expr* Expr::makeCaseExpr(Expr* expr, Expr* when) { + Expr* e = Expr::makeCase(when); + e->expr = expr; + return e; + } + + Expr* Expr::makeCaseExpr(Expr* expr, Expr* when, Expr* other) { + Expr* e = Expr::makeCase(when, other); + e->expr = expr; return e; } diff --git a/src/sql/Expr.h b/src/sql/Expr.h index dcaddee..e4ede07 100644 --- a/src/sql/Expr.h +++ b/src/sql/Expr.h @@ -25,7 +25,8 @@ namespace hsql { kExprSelect, kExprHint, kExprArray, - kExprArrayIndex + kExprArrayIndex, + kExprWhenCondition }; // Operator types. These are important for expressions of type kExprOperator. @@ -113,9 +114,17 @@ namespace hsql { static Expr* makeBetween(Expr* expr, Expr* left, Expr* right); - static Expr* makeCase(Expr* expr, Expr* then); + static Expr* makeCaseCondition(Expr* expr, Expr* then); - static Expr* makeCase(Expr* expr, Expr* then, Expr* other); + static Expr* joinCaseCondition(Expr* expr, Expr* then); + + static Expr* makeCase(Expr* when); + + static Expr* makeCase(Expr* when, Expr* other); + + static Expr* makeCaseExpr(Expr* expr, Expr* when); + + static Expr* makeCaseExpr(Expr* expr, Expr* when, Expr* other); static Expr* makeLiteral(int64_t val); diff --git a/test/select_tests.cpp b/test/select_tests.cpp index 2d25b92..5353201 100644 --- a/test/select_tests.cpp +++ b/test/select_tests.cpp @@ -214,9 +214,74 @@ TEST(SelectCaseWhen) { ASSERT_NOTNULL(caseExpr); ASSERT(caseExpr->isType(kExprOperator)); ASSERT_EQ(caseExpr->opType, kOpCase); - ASSERT(caseExpr->expr->isType(kExprOperator)); - ASSERT_EQ(caseExpr->expr->opType, kOpEquals); + ASSERT_NULL(caseExpr->expr); + ASSERT_NOTNULL(caseExpr->exprList); + ASSERT_NOTNULL(caseExpr->expr2); + ASSERT_EQ(caseExpr->exprList->size(), 1); + ASSERT(caseExpr->expr2->isType(kExprLiteralInt)); + + Expr* whenExpr = caseExpr->exprList->at(0); + ASSERT(whenExpr->expr->isType(kExprOperator)); + ASSERT_EQ(whenExpr->expr->opType, kOpEquals); + ASSERT(whenExpr->expr->expr->isType(kExprColumnRef)); + ASSERT(whenExpr->expr->expr2->isType(kExprLiteralString)); +} + +TEST(SelectCaseWhenWhen) { + TEST_PARSE_SINGLE_SQL( + "SELECT CASE WHEN x = 1 THEN 1 WHEN 1.25 < x THEN 2 END FROM test;", + kStmtSelect, + SelectStatement, + result, + stmt); + + ASSERT_EQ(stmt->selectList->size(), 1); + Expr* caseExpr = stmt->selectList->at(0); + ASSERT_NOTNULL(caseExpr); + ASSERT(caseExpr->isType(kExprOperator)); + ASSERT_EQ(caseExpr->opType, kOpCase); + // CASE [expr] [exprList] [expr2] + // [expr] [expr] + // [expr] [expr2] [expr] [expr2] + // CASE (null) WHEN X = 1 THEN 1 WHEN 1.25 < x THEN 2 (null) + ASSERT_NULL(caseExpr->expr); + ASSERT_NOTNULL(caseExpr->exprList); + ASSERT_NULL(caseExpr->expr2); ASSERT_EQ(caseExpr->exprList->size(), 2); + + Expr* whenExpr = caseExpr->exprList->at(0); + ASSERT_EQ(whenExpr->expr->opType, kOpEquals); + ASSERT(whenExpr->expr->expr->isType(kExprColumnRef)); + ASSERT(whenExpr->expr->expr2->isType(kExprLiteralInt)); + + Expr* whenExpr2 = caseExpr->exprList->at(1); + ASSERT_EQ(whenExpr2->expr->opType, kOpLess); + ASSERT(whenExpr2->expr->expr->isType(kExprLiteralFloat)); + ASSERT(whenExpr2->expr->expr2->isType(kExprColumnRef)); +} + +TEST(SelectCaseValueWhenWhenElse) { + TEST_PARSE_SINGLE_SQL( + "SELECT CASE x WHEN 1 THEN 0 WHEN 2 THEN 3 WHEN 8 THEN 7 ELSE 4 END FROM test;", + kStmtSelect, + SelectStatement, + result, + stmt); + + ASSERT_EQ(stmt->selectList->size(), 1); + Expr* caseExpr = stmt->selectList->at(0); + ASSERT_NOTNULL(caseExpr); + ASSERT(caseExpr->isType(kExprOperator)); + ASSERT_EQ(caseExpr->opType, kOpCase); + ASSERT_NOTNULL(caseExpr->expr); + ASSERT_NOTNULL(caseExpr->exprList); + ASSERT_NOTNULL(caseExpr->expr2); + ASSERT_EQ(caseExpr->exprList->size(), 3); + ASSERT(caseExpr->expr->isType(kExprColumnRef)); + + Expr* whenExpr = caseExpr->exprList->at(2); + ASSERT(whenExpr->expr->isType(kExprLiteralInt)); + ASSERT_EQ(whenExpr->expr2->ival, 7); } TEST(SelectJoin) {