quickstep-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From zu...@apache.org
Subject incubator-quickstep git commit: Refactored ScalarCaseExpression.
Date Fri, 04 May 2018 19:59:28 GMT
Repository: incubator-quickstep
Updated Branches:
  refs/heads/master 77287a788 -> 666102fff


Refactored ScalarCaseExpression.


Project: http://git-wip-us.apache.org/repos/asf/incubator-quickstep/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-quickstep/commit/666102ff
Tree: http://git-wip-us.apache.org/repos/asf/incubator-quickstep/tree/666102ff
Diff: http://git-wip-us.apache.org/repos/asf/incubator-quickstep/diff/666102ff

Branch: refs/heads/master
Commit: 666102fff32d258a5a2c85e33bf8d8ebb5d3a9cf
Parents: 77287a7
Author: Zuyu Zhang <zuyu@cs.wisc.edu>
Authored: Wed May 2 16:06:59 2018 -0500
Committer: Zuyu Zhang <zuyu@cs.wisc.edu>
Committed: Fri May 4 14:40:28 2018 -0500

----------------------------------------------------------------------
 expressions/scalar/ScalarCaseExpression.cpp     | 265 ++++++++--------
 expressions/scalar/ScalarCaseExpression.hpp     |  26 --
 .../tests/ScalarCaseExpression_unittest.cpp     | 307 +++++++++++++++++++
 3 files changed, 450 insertions(+), 148 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/666102ff/expressions/scalar/ScalarCaseExpression.cpp
----------------------------------------------------------------------
diff --git a/expressions/scalar/ScalarCaseExpression.cpp b/expressions/scalar/ScalarCaseExpression.cpp
index 6847425..c2af83b 100644
--- a/expressions/scalar/ScalarCaseExpression.cpp
+++ b/expressions/scalar/ScalarCaseExpression.cpp
@@ -41,6 +41,102 @@
 
 namespace quickstep {
 
+namespace {
+
+// Merge the values in the NativeColumnVector 'case_result' into '*output' at
+// the positions specified by 'case_matches'. If 'source_sequence' is
+// non-NULL, it indicates which positions actually have tuples in the input,
+// otherwise it is assumed that there are no holes in the input.
+void MultiplexNativeColumnVector(
+    const TupleIdSequence *source_sequence,
+    const TupleIdSequence &case_matches,
+    const NativeColumnVector &case_result,
+    NativeColumnVector *output) {
+  if (source_sequence == nullptr) {
+    if (case_result.typeIsNullable()) {
+      TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
+      for (std::size_t input_pos = 0;
+           input_pos < case_result.size();
+           ++input_pos, ++output_pos_it) {
+        const void *value = case_result.getUntypedValue<true>(input_pos);
+        if (value) {
+          output->positionalWriteUntypedValue(*output_pos_it, value);
+        } else {
+          output->positionalWriteNullValue(*output_pos_it);
+        }
+      }
+    } else {
+      TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
+      for (std::size_t input_pos = 0;
+           input_pos < case_result.size();
+           ++input_pos, ++output_pos_it) {
+        output->positionalWriteUntypedValue(*output_pos_it,
+                                            case_result.getUntypedValue<false>(input_pos));
+      }
+    }
+  } else {
+    if (case_result.typeIsNullable()) {
+      std::size_t input_pos = 0;
+      TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
+      for (std::size_t output_pos = 0;
+           output_pos < output->size();
+           ++output_pos, ++source_sequence_it) {
+        if (case_matches.get(*source_sequence_it)) {
+          const void *value = case_result.getUntypedValue<true>(input_pos++);
+          if (value) {
+            output->positionalWriteUntypedValue(output_pos, value);
+          } else {
+            output->positionalWriteNullValue(output_pos);
+          }
+        }
+      }
+    } else {
+      std::size_t input_pos = 0;
+      TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
+      for (std::size_t output_pos = 0;
+           output_pos < output->size();
+           ++output_pos, ++source_sequence_it) {
+        if (case_matches.get(*source_sequence_it)) {
+          output->positionalWriteUntypedValue(output_pos,
+                                              case_result.getUntypedValue<false>(input_pos++));
+        }
+      }
+    }
+  }
+}
+
+// Same as MultiplexNativeColumnVector(), but works on IndirectColumnVectors
+// instead of NativeColumnVectors.
+void MultiplexIndirectColumnVector(
+    const TupleIdSequence *source_sequence,
+    const TupleIdSequence &case_matches,
+    const IndirectColumnVector &case_result,
+    IndirectColumnVector *output) {
+  if (source_sequence == nullptr) {
+    TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
+    for (std::size_t input_pos = 0;
+         input_pos < case_result.size();
+         ++input_pos, ++output_pos_it) {
+      output->positionalWriteTypedValue(*output_pos_it,
+                                        case_result.getTypedValue(input_pos));
+    }
+  } else {
+    std::size_t input_pos = 0;
+    TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
+    for (std::size_t output_pos = 0;
+         output_pos < output->size();
+         ++output_pos, ++source_sequence_it) {
+      if (case_matches.get(*source_sequence_it)) {
+        output->positionalWriteTypedValue(output_pos,
+                                          case_result.getTypedValue(input_pos++));
+      }
+    }
+  }
+}
+
+}  // namespace
+
+
 ScalarCaseExpression::ScalarCaseExpression(
     const Type &result_type,
     std::vector<std::unique_ptr<Predicate>> &&when_predicates,
@@ -96,17 +192,17 @@ serialization::Scalar ScalarCaseExpression::getProto() const {
   serialization::Scalar proto;
   proto.set_data_source(serialization::Scalar::CASE_EXPRESSION);
   proto.MutableExtension(serialization::ScalarCaseExpression::result_type)
-      ->CopyFrom(type_.getProto());
+      ->MergeFrom(type_.getProto());
   for (const std::unique_ptr<Predicate> &when_pred : when_predicates_) {
     proto.AddExtension(serialization::ScalarCaseExpression::when_predicate)
-      ->CopyFrom(when_pred->getProto());
+        ->MergeFrom(when_pred->getProto());
   }
   for (const std::unique_ptr<Scalar> &result_expr : result_expressions_) {
     proto.AddExtension(serialization::ScalarCaseExpression::result_expression)
-      ->CopyFrom(result_expr->getProto());
+        ->MergeFrom(result_expr->getProto());
   }
   proto.MutableExtension(serialization::ScalarCaseExpression::else_result_expression)
-      ->CopyFrom(else_result_expression_->getProto());
+      ->MergeFrom(else_result_expression_->getProto());
 
   return proto;
 }
@@ -137,16 +233,16 @@ TypedValue ScalarCaseExpression::getValueForSingleTuple(
     return static_value_.makeReferenceToThis();
   } else if (fixed_result_expression_ != nullptr) {
     return fixed_result_expression_->getValueForSingleTuple(accessor, tuple);
-  } else {
-    for (std::vector<std::unique_ptr<Predicate>>::size_type case_idx = 0;
-         case_idx < when_predicates_.size();
-         ++case_idx) {
-      if (when_predicates_[case_idx]->matchesForSingleTuple(accessor, tuple)) {
-        return result_expressions_[case_idx]->getValueForSingleTuple(accessor, tuple);
-      }
+  }
+
+  for (std::vector<std::unique_ptr<Predicate>>::size_type case_idx = 0;
+       case_idx < when_predicates_.size();
+       ++case_idx) {
+    if (when_predicates_[case_idx]->matchesForSingleTuple(accessor, tuple)) {
+      return result_expressions_[case_idx]->getValueForSingleTuple(accessor, tuple);
     }
-    return else_result_expression_->getValueForSingleTuple(accessor, tuple);
   }
+  return else_result_expression_->getValueForSingleTuple(accessor, tuple);
 }
 
 TypedValue ScalarCaseExpression::getValueForJoinedTuples(
@@ -165,33 +261,33 @@ TypedValue ScalarCaseExpression::getValueForJoinedTuples(
                                                              right_accessor,
                                                              right_relation_id,
                                                              right_tuple_id);
-  } else {
-    for (std::vector<std::unique_ptr<Predicate>>::size_type case_idx = 0;
-         case_idx < when_predicates_.size();
-         ++case_idx) {
-      if (when_predicates_[case_idx]->matchesForJoinedTuples(left_accessor,
-                                                             left_relation_id,
-                                                             left_tuple_id,
-                                                             right_accessor,
-                                                             right_relation_id,
-                                                             right_tuple_id)) {
-        return result_expressions_[case_idx]->getValueForJoinedTuples(
-            left_accessor,
-            left_relation_id,
-            left_tuple_id,
-            right_accessor,
-            right_relation_id,
-            right_tuple_id);
-      }
+  }
+
+  for (std::vector<std::unique_ptr<Predicate>>::size_type case_idx = 0;
+       case_idx < when_predicates_.size();
+       ++case_idx) {
+    if (when_predicates_[case_idx]->matchesForJoinedTuples(left_accessor,
+                                                           left_relation_id,
+                                                           left_tuple_id,
+                                                           right_accessor,
+                                                           right_relation_id,
+                                                           right_tuple_id)) {
+      return result_expressions_[case_idx]->getValueForJoinedTuples(
+          left_accessor,
+          left_relation_id,
+          left_tuple_id,
+          right_accessor,
+          right_relation_id,
+          right_tuple_id);
     }
-    return else_result_expression_->getValueForJoinedTuples(
-        left_accessor,
-        left_relation_id,
-        left_tuple_id,
-        right_accessor,
-        right_relation_id,
-        right_tuple_id);
   }
+  return else_result_expression_->getValueForJoinedTuples(
+      left_accessor,
+      left_relation_id,
+      left_tuple_id,
+      right_accessor,
+      right_relation_id,
+      right_tuple_id);
 }
 
 ColumnVectorPtr ScalarCaseExpression::getAllValues(
@@ -280,6 +376,16 @@ ColumnVectorPtr ScalarCaseExpression::getAllValuesForJoin(
     ValueAccessor *right_accessor,
     const std::vector<std::pair<tuple_id, tuple_id>> &joined_tuple_ids,
     ColumnVectorCache *cv_cache) const {
+  if (has_static_value_) {
+    return ColumnVectorPtr(
+        ColumnVector::MakeVectorOfValue(type_, static_value_, joined_tuple_ids.size()));
+  } else if (fixed_result_expression_) {
+    return fixed_result_expression_->getAllValuesForJoin(
+        left_relation_id, left_accessor,
+        right_relation_id, right_accessor,
+        joined_tuple_ids, cv_cache);
+  }
+
   // Slice 'joined_tuple_ids' apart by case.
   //
   // NOTE(chasseur): We use TupleIdSequence to keep track of the positions in
@@ -368,91 +474,6 @@ ColumnVectorPtr ScalarCaseExpression::getAllValuesForJoin(
       else_results);
 }
 
-void ScalarCaseExpression::MultiplexNativeColumnVector(
-    const TupleIdSequence *source_sequence,
-    const TupleIdSequence &case_matches,
-    const NativeColumnVector &case_result,
-    NativeColumnVector *output) {
-  if (source_sequence == nullptr) {
-    if (case_result.typeIsNullable()) {
-      TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
-      for (std::size_t input_pos = 0;
-           input_pos < case_result.size();
-           ++input_pos, ++output_pos_it) {
-        const void *value = case_result.getUntypedValue<true>(input_pos);
-        if (value == nullptr) {
-          output->positionalWriteNullValue(*output_pos_it);
-        } else {
-          output->positionalWriteUntypedValue(*output_pos_it, value);
-        }
-      }
-    } else {
-      TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
-      for (std::size_t input_pos = 0;
-           input_pos < case_result.size();
-           ++input_pos, ++output_pos_it) {
-        output->positionalWriteUntypedValue(*output_pos_it,
-                                            case_result.getUntypedValue<false>(input_pos));
-      }
-    }
-  } else {
-    if (case_result.typeIsNullable()) {
-      std::size_t input_pos = 0;
-      TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
-      for (std::size_t output_pos = 0;
-           output_pos < output->size();
-           ++output_pos, ++source_sequence_it) {
-        if (case_matches.get(*source_sequence_it)) {
-          const void *value = case_result.getUntypedValue<true>(input_pos++);
-          if (value == nullptr) {
-            output->positionalWriteNullValue(output_pos);
-          } else {
-            output->positionalWriteUntypedValue(output_pos, value);
-          }
-        }
-      }
-    } else {
-      std::size_t input_pos = 0;
-      TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
-      for (std::size_t output_pos = 0;
-           output_pos < output->size();
-           ++output_pos, ++source_sequence_it) {
-        if (case_matches.get(*source_sequence_it)) {
-          output->positionalWriteUntypedValue(output_pos,
-                                              case_result.getUntypedValue<false>(input_pos++));
-        }
-      }
-    }
-  }
-}
-
-void ScalarCaseExpression::MultiplexIndirectColumnVector(
-    const TupleIdSequence *source_sequence,
-    const TupleIdSequence &case_matches,
-    const IndirectColumnVector &case_result,
-    IndirectColumnVector *output) {
-  if (source_sequence == nullptr) {
-    TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
-    for (std::size_t input_pos = 0;
-         input_pos < case_result.size();
-         ++input_pos, ++output_pos_it) {
-      output->positionalWriteTypedValue(*output_pos_it,
-                                        case_result.getTypedValue(input_pos));
-    }
-  } else {
-    std::size_t input_pos = 0;
-    TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
-    for (std::size_t output_pos = 0;
-         output_pos < output->size();
-         ++output_pos, ++source_sequence_it) {
-      if (case_matches.get(*source_sequence_it)) {
-        output->positionalWriteTypedValue(output_pos,
-                                          case_result.getTypedValue(input_pos++));
-      }
-    }
-  }
-}
-
 ColumnVectorPtr ScalarCaseExpression::multiplexColumnVectors(
     const std::size_t output_size,
     const TupleIdSequence *source_sequence,

http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/666102ff/expressions/scalar/ScalarCaseExpression.hpp
----------------------------------------------------------------------
diff --git a/expressions/scalar/ScalarCaseExpression.hpp b/expressions/scalar/ScalarCaseExpression.hpp
index 3d0ed71..22acfa8 100644
--- a/expressions/scalar/ScalarCaseExpression.hpp
+++ b/expressions/scalar/ScalarCaseExpression.hpp
@@ -124,14 +124,6 @@ class ScalarCaseExpression : public Scalar {
     }
   }
 
-  relation_id getRelationIdForValueAccessor() const override {
-    if (fixed_result_expression_ != nullptr) {
-      return fixed_result_expression_->getRelationIdForValueAccessor();
-    } else {
-      return -1;
-    }
-  }
-
   ColumnVectorPtr getAllValues(ValueAccessor *accessor,
                                const SubBlocksReference *sub_blocks_ref,
                                ColumnVectorCache *cv_cache) const override;
@@ -154,24 +146,6 @@ class ScalarCaseExpression : public Scalar {
      std::vector<std::vector<const Expression*>> *container_child_fields) const override;
 
  private:
-  // Merge the values in the NativeColumnVector 'case_result' into '*output' at
-  // the positions specified by 'case_matches'. If '*source_sequence' is
-  // non-NULL, it indicates which positions actually have tuples in the input,
-  // otherwise it is assumed that there are no holes in the input.
-  static void MultiplexNativeColumnVector(
-      const TupleIdSequence *source_sequence,
-      const TupleIdSequence &case_matches,
-      const NativeColumnVector &case_result,
-      NativeColumnVector *output);
-
-  // Same as MultiplexNativeColumnVector(), but works on IndirectColumnVectors
-  // instead of NativeColumnVectors.
-  static void MultiplexIndirectColumnVector(
-      const TupleIdSequence *source_sequence,
-      const TupleIdSequence &case_matches,
-      const IndirectColumnVector &case_result,
-      IndirectColumnVector *output);
-
   // Create and return a new ColumnVector by multiplexing the ColumnVectors
   // containing results for individual CASE branches at the appropriate
   // positions. 'output_size' is the total number of values in the output.

http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/666102ff/expressions/scalar/tests/ScalarCaseExpression_unittest.cpp
----------------------------------------------------------------------
diff --git a/expressions/scalar/tests/ScalarCaseExpression_unittest.cpp b/expressions/scalar/tests/ScalarCaseExpression_unittest.cpp
index 7182642..f385b74 100644
--- a/expressions/scalar/tests/ScalarCaseExpression_unittest.cpp
+++ b/expressions/scalar/tests/ScalarCaseExpression_unittest.cpp
@@ -875,6 +875,313 @@ TEST_F(ScalarCaseExpressionTest,
   }
 }
 
+// Test CASE evaluation over joins where the same branch is always taken, and
+// that branch's result is a constant.
+TEST_F(ScalarCaseExpressionTest, JoinStaticBranchConstantTest) {
+  // Simulate a join with another relation.
+  CatalogRelation other_relation(nullptr, "other", 1);
+  other_relation.addAttribute(new CatalogAttribute(&other_relation,
+                                                   "other_double",
+                                                   TypeFactory::GetType(kDouble, false)));
+  other_relation.addAttribute(new CatalogAttribute(&other_relation,
+                                                   "other_int",
+                                                   TypeFactory::GetType(kInt, false)));
+
+  static const double kOtherDoubleValues[] = {-250.0, -750.0};
+  std::unique_ptr<NativeColumnVector> other_double_column(
+      new NativeColumnVector(TypeFactory::GetType(kDouble, false), 2));
+  other_double_column->appendUntypedValue(kOtherDoubleValues);
+  other_double_column->appendUntypedValue(kOtherDoubleValues + 1);
+
+  static const int kOtherIntValues[] = {10, -1};
+  std::unique_ptr<NativeColumnVector> other_int_column(
+      new NativeColumnVector(TypeFactory::GetType(kInt, false), 2));
+  other_int_column->appendUntypedValue(kOtherIntValues);
+  other_int_column->appendUntypedValue(kOtherIntValues + 1);
+
+  ColumnVectorsValueAccessor other_accessor;
+  other_accessor.addColumn(other_double_column.release());
+  other_accessor.addColumn(other_int_column.release());
+
+  const Type &int_type = TypeFactory::GetType(kInt);
+
+  // Setup expression.
+  std::vector<std::unique_ptr<Predicate>> when_predicates;
+  std::vector<std::unique_ptr<Scalar>> result_expressions;
+
+  // WHEN 1 > 2 THEN int_attr + other_int
+  when_predicates.emplace_back(new ComparisonPredicate(
+      ComparisonFactory::GetComparison(ComparisonID::kGreater),
+      new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+      new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+  result_expressions.emplace_back(new ScalarBinaryExpression(
+      BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd),
+      new ScalarAttribute(*sample_relation_->getAttributeById(0)),
+      new ScalarAttribute(*other_relation.getAttributeById(1))));
+
+  const int kConstant = 72;
+  // WHEN 1 < 2 THEN kConstant
+  when_predicates.emplace_back(new ComparisonPredicate(
+      ComparisonFactory::GetComparison(ComparisonID::kLess),
+      new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+      new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+  result_expressions.emplace_back(
+      new ScalarLiteral(TypedValue(kConstant), int_type));
+
+  // WHEN double_attr = other_double THEN 0
+  when_predicates.emplace_back(new ComparisonPredicate(
+      ComparisonFactory::GetComparison(ComparisonID::kEqual),
+      new ScalarAttribute(*sample_relation_->getAttributeById(1)),
+      new ScalarAttribute(*other_relation.getAttributeById(0))));
+  result_expressions.emplace_back(new ScalarLiteral(TypedValue(0), TypeFactory::GetType(kInt)));
+
+  const Type &int_nullable_type = TypeFactory::GetType(kInt, true);
+
+  // ELSE NULL
+  ScalarCaseExpression case_expr(
+      int_nullable_type,
+      std::move(when_predicates),
+      std::move(result_expressions),
+      new ScalarLiteral(TypedValue(kInt), int_nullable_type));
+
+  // Create a list of joined tuple-id pairs (just the cross-product of tuples).
+  std::vector<std::pair<tuple_id, tuple_id>> joined_tuple_ids;
+  for (std::size_t tuple_num = 0; tuple_num < kNumSampleTuples; ++tuple_num) {
+    joined_tuple_ids.emplace_back(tuple_num, 0);
+    joined_tuple_ids.emplace_back(tuple_num, 1);
+  }
+
+  ColumnVectorPtr result_cv(case_expr.getAllValuesForJoin(
+      0,
+      &sample_data_value_accessor_,
+      1,
+      &other_accessor,
+      joined_tuple_ids,
+      nullptr /* cv_cache */));
+  ASSERT_TRUE(result_cv->isNative());
+  const NativeColumnVector &native_result_cv
+      = static_cast<const NativeColumnVector&>(*result_cv);
+  EXPECT_EQ(kNumSampleTuples * 2, native_result_cv.size());
+
+  for (std::size_t result_num = 0;
+       result_num < native_result_cv.size();
+       ++result_num) {
+    EXPECT_EQ(kConstant,
+              *static_cast<const int*>(native_result_cv.getUntypedValue(result_num)));
+  }
+}
+
+// Test CASE evaluation over joins where the same branch is always taken, and
+// that branch's result is a ScalarAttribute.
+TEST_F(ScalarCaseExpressionTest, JoinStaticBranchOnScalarAttributeTest) {
+  // Simulate a join with another relation.
+  CatalogRelation other_relation(nullptr, "other", 1);
+  other_relation.addAttribute(new CatalogAttribute(&other_relation,
+                                                   "other_double",
+                                                   TypeFactory::GetType(kDouble, false)));
+  other_relation.addAttribute(new CatalogAttribute(&other_relation,
+                                                   "other_int",
+                                                   TypeFactory::GetType(kInt, false)));
+
+  static const double kOtherDoubleValues[] = {-250.0, -750.0};
+  std::unique_ptr<NativeColumnVector> other_double_column(
+      new NativeColumnVector(TypeFactory::GetType(kDouble, false), 2));
+  other_double_column->appendUntypedValue(kOtherDoubleValues);
+  other_double_column->appendUntypedValue(kOtherDoubleValues + 1);
+
+  static const int kOtherIntValues[] = {10, -1};
+  std::unique_ptr<NativeColumnVector> other_int_column(
+      new NativeColumnVector(TypeFactory::GetType(kInt, false), 2));
+  other_int_column->appendUntypedValue(kOtherIntValues);
+  other_int_column->appendUntypedValue(kOtherIntValues + 1);
+
+  ColumnVectorsValueAccessor other_accessor;
+  other_accessor.addColumn(other_double_column.release());
+  other_accessor.addColumn(other_int_column.release());
+
+  const Type &int_type = TypeFactory::GetType(kInt);
+
+  // Setup expression.
+  std::vector<std::unique_ptr<Predicate>> when_predicates;
+  std::vector<std::unique_ptr<Scalar>> result_expressions;
+
+  // WHEN 1 > 2 THEN int_attr + other_int
+  when_predicates.emplace_back(new ComparisonPredicate(
+      ComparisonFactory::GetComparison(ComparisonID::kGreater),
+      new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+      new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+  result_expressions.emplace_back(new ScalarBinaryExpression(
+      BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd),
+      new ScalarAttribute(*sample_relation_->getAttributeById(0)),
+      new ScalarAttribute(*other_relation.getAttributeById(1))));
+
+  // WHEN 1 < 2 THEN int_attr
+  when_predicates.emplace_back(new ComparisonPredicate(
+      ComparisonFactory::GetComparison(ComparisonID::kLess),
+      new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+      new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+  result_expressions.emplace_back(
+      new ScalarAttribute(*sample_relation_->getAttributeById(0)));
+
+  // WHEN double_attr = other_double THEN 0
+  when_predicates.emplace_back(new ComparisonPredicate(
+      ComparisonFactory::GetComparison(ComparisonID::kEqual),
+      new ScalarAttribute(*sample_relation_->getAttributeById(1)),
+      new ScalarAttribute(*other_relation.getAttributeById(0))));
+  result_expressions.emplace_back(new ScalarLiteral(TypedValue(0), TypeFactory::GetType(kInt)));
+
+  const Type &int_nullable_type = TypeFactory::GetType(kInt, true);
+
+  // ELSE NULL
+  ScalarCaseExpression case_expr(
+      int_nullable_type,
+      std::move(when_predicates),
+      std::move(result_expressions),
+      new ScalarLiteral(TypedValue(kInt), int_nullable_type));
+
+  // Create a list of joined tuple-id pairs (just the cross-product of tuples).
+  std::vector<std::pair<tuple_id, tuple_id>> joined_tuple_ids;
+  for (std::size_t tuple_num = 0; tuple_num < kNumSampleTuples; ++tuple_num) {
+    joined_tuple_ids.emplace_back(tuple_num, 0);
+    joined_tuple_ids.emplace_back(tuple_num, 1);
+  }
+
+  ColumnVectorPtr result_cv(case_expr.getAllValuesForJoin(
+      0,
+      &sample_data_value_accessor_,
+      1,
+      &other_accessor,
+      joined_tuple_ids,
+      nullptr /* cv_cache */));
+  ASSERT_TRUE(result_cv->isNative());
+  const NativeColumnVector &native_result_cv
+      = static_cast<const NativeColumnVector&>(*result_cv);
+  EXPECT_EQ(kNumSampleTuples * 2, native_result_cv.size());
+
+  for (std::size_t result_num = 0;
+       result_num < native_result_cv.size();
+       ++result_num) {
+    // For convenience, calculate expected tuple values here.
+    const bool sample_int_null = ((result_num >> 1) % 10 == 0);
+    const int sample_int = result_num >> 1;
+
+    if (sample_int_null) {
+      EXPECT_EQ(nullptr, native_result_cv.getUntypedValue(result_num));
+    } else {
+      ASSERT_NE(nullptr, native_result_cv.getUntypedValue(result_num));
+      EXPECT_EQ(sample_int,
+                *static_cast<const int*>(native_result_cv.getUntypedValue(result_num)));
+    }
+  }
+}
+
+// Test CASE evaluation over joins where the same branch is always taken, and
+// that branch's result is a ScalarBinaryExpression.
+TEST_F(ScalarCaseExpressionTest, JoinStaticBranchTest) {
+  // Simulate a join with another relation.
+  CatalogRelation other_relation(nullptr, "other", 1);
+  other_relation.addAttribute(new CatalogAttribute(&other_relation,
+                                                   "other_double",
+                                                   TypeFactory::GetType(kDouble, false)));
+  other_relation.addAttribute(new CatalogAttribute(&other_relation,
+                                                   "other_int",
+                                                   TypeFactory::GetType(kInt, false)));
+
+  static const double kOtherDoubleValues[] = {-250.0, -750.0};
+  std::unique_ptr<NativeColumnVector> other_double_column(
+      new NativeColumnVector(TypeFactory::GetType(kDouble, false), 2));
+  other_double_column->appendUntypedValue(kOtherDoubleValues);
+  other_double_column->appendUntypedValue(kOtherDoubleValues + 1);
+
+  static const int kOtherIntValues[] = {10, -1};
+  std::unique_ptr<NativeColumnVector> other_int_column(
+      new NativeColumnVector(TypeFactory::GetType(kInt, false), 2));
+  other_int_column->appendUntypedValue(kOtherIntValues);
+  other_int_column->appendUntypedValue(kOtherIntValues + 1);
+
+  ColumnVectorsValueAccessor other_accessor;
+  other_accessor.addColumn(other_double_column.release());
+  other_accessor.addColumn(other_int_column.release());
+
+  const Type &int_type = TypeFactory::GetType(kInt);
+
+  // Setup expression.
+  std::vector<std::unique_ptr<Predicate>> when_predicates;
+  std::vector<std::unique_ptr<Scalar>> result_expressions;
+
+  // WHEN 1 > 2 THEN int_attr + other_int
+  when_predicates.emplace_back(new ComparisonPredicate(
+      ComparisonFactory::GetComparison(ComparisonID::kGreater),
+      new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+      new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+  result_expressions.emplace_back(new ScalarBinaryExpression(
+      BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd),
+      new ScalarAttribute(*sample_relation_->getAttributeById(0)),
+      new ScalarAttribute(*other_relation.getAttributeById(1))));
+
+  // WHEN 1 < 2 THEN int_attr * other_int
+  when_predicates.emplace_back(new ComparisonPredicate(
+      ComparisonFactory::GetComparison(ComparisonID::kLess),
+      new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+      new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+  result_expressions.emplace_back(new ScalarBinaryExpression(
+      BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply),
+      new ScalarAttribute(*sample_relation_->getAttributeById(0)),
+      new ScalarAttribute(*other_relation.getAttributeById(1))));
+
+  // WHEN double_attr = other_double THEN 0
+  when_predicates.emplace_back(new ComparisonPredicate(
+      ComparisonFactory::GetComparison(ComparisonID::kEqual),
+      new ScalarAttribute(*sample_relation_->getAttributeById(1)),
+      new ScalarAttribute(*other_relation.getAttributeById(0))));
+  result_expressions.emplace_back(new ScalarLiteral(TypedValue(0), TypeFactory::GetType(kInt)));
+
+  const Type &int_nullable_type = TypeFactory::GetType(kInt, true);
+
+  // ELSE NULL
+  ScalarCaseExpression case_expr(
+      int_nullable_type,
+      std::move(when_predicates),
+      std::move(result_expressions),
+      new ScalarLiteral(TypedValue(kInt), int_nullable_type));
+
+  // Create a list of joined tuple-id pairs (just the cross-product of tuples).
+  std::vector<std::pair<tuple_id, tuple_id>> joined_tuple_ids;
+  for (std::size_t tuple_num = 0; tuple_num < kNumSampleTuples; ++tuple_num) {
+    joined_tuple_ids.emplace_back(tuple_num, 0);
+    joined_tuple_ids.emplace_back(tuple_num, 1);
+  }
+
+  ColumnVectorPtr result_cv(case_expr.getAllValuesForJoin(
+      0,
+      &sample_data_value_accessor_,
+      1,
+      &other_accessor,
+      joined_tuple_ids,
+      nullptr /* cv_cache */));
+  ASSERT_TRUE(result_cv->isNative());
+  const NativeColumnVector &native_result_cv
+      = static_cast<const NativeColumnVector&>(*result_cv);
+  EXPECT_EQ(kNumSampleTuples * 2, native_result_cv.size());
+
+  for (std::size_t result_num = 0;
+       result_num < native_result_cv.size();
+       ++result_num) {
+    // For convenience, calculate expected tuple values here.
+    const bool sample_int_null = ((result_num >> 1) % 10 == 0);
+    const int sample_int = result_num >> 1;
+    const int other_int = kOtherIntValues[result_num & 0x1];
+
+    if (sample_int_null) {
+      EXPECT_EQ(nullptr, native_result_cv.getUntypedValue(result_num));
+    } else {
+      ASSERT_NE(nullptr, native_result_cv.getUntypedValue(result_num));
+      EXPECT_EQ(sample_int * other_int,
+                *static_cast<const int*>(native_result_cv.getUntypedValue(result_num)));
+    }
+  }
+}
+
 // Test CASE evaluation over joins, with both WHEN predicates and THEN
 // expressions referencing attributes in both relations.
 TEST_F(ScalarCaseExpressionTest, JoinTest) {


Mime
View raw message