spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-18891][SQL] Support for Scala Map collection types
Date Mon, 12 Jun 2017 00:47:10 GMT
Repository: spark
Updated Branches:
  refs/heads/master a7c61c100 -> 0538f3b0a


[SPARK-18891][SQL] Support for Scala Map collection types

## What changes were proposed in this pull request?

Add support for arbitrary Scala `Map` types in deserialization as well as a generic implicit encoder.

Used the builder approach as in #16541 to construct any provided `Map` type upon deserialization.

Please note that this PR also adds (ignored) tests for issue [SPARK-19104 CompileException with Map and Case Class in Spark 2.1.0](https://issues.apache.org/jira/browse/SPARK-19104) but doesn't solve it.

Added support for Java Maps in codegen code (encoders will be added in a different PR) with the following default implementations for interfaces/abstract classes:

* `java.util.Map`, `java.util.AbstractMap` => `java.util.HashMap`
* `java.util.SortedMap`, `java.util.NavigableMap` => `java.util.TreeMap`
* `java.util.concurrent.ConcurrentMap` => `java.util.concurrent.ConcurrentHashMap`
* `java.util.concurrent.ConcurrentNavigableMap` => `java.util.concurrent.ConcurrentSkipListMap`

Resulting codegen for `Seq(Map(1 -> 2)).toDS().map(identity).queryExecution.debug.codegen`:

```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */   private int CollectObjectsToMap_loopValue0;
/* 011 */   private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */   private int CollectObjectsToMap_loopValue2;
/* 013 */   private UnsafeRow deserializetoobject_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */   private scala.collection.immutable.Map mapelements_argValue;
/* 017 */   private UnsafeRow mapelements_result;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */   private UnsafeRow serializefromobject_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */   public GeneratedIterator(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */     wholestagecodegen_init_0();
/* 034 */     wholestagecodegen_init_1();
/* 035 */
/* 036 */   }
/* 037 */
/* 038 */   private void wholestagecodegen_init_0() {
/* 039 */     inputadapter_input = inputs[0];
/* 040 */
/* 041 */     deserializetoobject_result = new UnsafeRow(1);
/* 042 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */     mapelements_result = new UnsafeRow(1);
/* 046 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */     serializefromobject_result = new UnsafeRow(1);
/* 049 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */   }
/* 054 */
/* 055 */   private void wholestagecodegen_init_1() {
/* 056 */     this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */   }
/* 059 */
/* 060 */   protected void processNext() throws java.io.IOException {
/* 061 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */       MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */       boolean deserializetoobject_isNull1 = true;
/* 067 */       ArrayData deserializetoobject_value1 = null;
/* 068 */       if (!inputadapter_isNull) {
/* 069 */         deserializetoobject_isNull1 = false;
/* 070 */         if (!deserializetoobject_isNull1) {
/* 071 */           Object deserializetoobject_funcResult = null;
/* 072 */           deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */           if (deserializetoobject_funcResult == null) {
/* 074 */             deserializetoobject_isNull1 = true;
/* 075 */           } else {
/* 076 */             deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */           }
/* 078 */
/* 079 */         }
/* 080 */         deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */       }
/* 082 */
/* 083 */       boolean deserializetoobject_isNull3 = true;
/* 084 */       ArrayData deserializetoobject_value3 = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull3 = false;
/* 087 */         if (!deserializetoobject_isNull3) {
/* 088 */           Object deserializetoobject_funcResult1 = null;
/* 089 */           deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */           if (deserializetoobject_funcResult1 == null) {
/* 091 */             deserializetoobject_isNull3 = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */       }
/* 099 */       scala.collection.immutable.Map deserializetoobject_value = null;
/* 100 */
/* 101 */       if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */         (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */         throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */       }
/* 105 */
/* 106 */       if (!deserializetoobject_isNull1) {
/* 107 */         if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */           throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */         }
/* 110 */         int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */
/* 112 */         scala.collection.mutable.Builder CollectObjectsToMap_builderValue5 = scala.collection.immutable.Map$.MODULE$.newBuilder();
/* 113 */         CollectObjectsToMap_builderValue5.sizeHint(deserializetoobject_dataLength);
/* 114 */
/* 115 */         int deserializetoobject_loopIndex = 0;
/* 116 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 117 */           CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 118 */           CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 119 */           CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 120 */           CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 121 */
/* 122 */           if (CollectObjectsToMap_loopIsNull1) {
/* 123 */             throw new RuntimeException("Found null in map key!");
/* 124 */           }
/* 125 */
/* 126 */           scala.Tuple2 CollectObjectsToMap_loopValue4;
/* 127 */
/* 128 */           if (CollectObjectsToMap_loopIsNull3) {
/* 129 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, null);
/* 130 */           } else {
/* 131 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 132 */           }
/* 133 */
/* 134 */           CollectObjectsToMap_builderValue5.$plus$eq(CollectObjectsToMap_loopValue4);
/* 135 */
/* 136 */           deserializetoobject_loopIndex += 1;
/* 137 */         }
/* 138 */
/* 139 */         deserializetoobject_value = (scala.collection.immutable.Map) CollectObjectsToMap_builderValue5.result();
/* 140 */       }
/* 141 */
/* 142 */       boolean mapelements_isNull = true;
/* 143 */       scala.collection.immutable.Map mapelements_value = null;
/* 144 */       if (!false) {
/* 145 */         mapelements_argValue = deserializetoobject_value;
/* 146 */
/* 147 */         mapelements_isNull = false;
/* 148 */         if (!mapelements_isNull) {
/* 149 */           Object mapelements_funcResult = null;
/* 150 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 151 */           if (mapelements_funcResult == null) {
/* 152 */             mapelements_isNull = true;
/* 153 */           } else {
/* 154 */             mapelements_value = (scala.collection.immutable.Map) mapelements_funcResult;
/* 155 */           }
/* 156 */
/* 157 */         }
/* 158 */         mapelements_isNull = mapelements_value == null;
/* 159 */       }
/* 160 */
/* 161 */       MapData serializefromobject_value = null;
/* 162 */       if (!mapelements_isNull) {
/* 163 */         final int serializefromobject_length = mapelements_value.size();
/* 164 */         final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 165 */         final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 166 */         int serializefromobject_index = 0;
/* 167 */         final scala.collection.Iterator serializefromobject_entries = mapelements_value.iterator();
/* 168 */         while(serializefromobject_entries.hasNext()) {
/* 169 */           final scala.Tuple2 serializefromobject_entry = (scala.Tuple2) serializefromobject_entries.next();
/* 170 */           int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry._1();
/* 171 */           int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry._2();
/* 172 */
/* 173 */           boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 174 */
/* 175 */           if (false) {
/* 176 */             throw new RuntimeException("Cannot use null as map key!");
/* 177 */           } else {
/* 178 */             serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 179 */           }
/* 180 */
/* 181 */           if (false) {
/* 182 */             serializefromobject_convertedValues[serializefromobject_index] = null;
/* 183 */           } else {
/* 184 */             serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 185 */           }
/* 186 */
/* 187 */           serializefromobject_index++;
/* 188 */         }
/* 189 */
/* 190 */         serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 191 */       }
/* 192 */       serializefromobject_holder.reset();
/* 193 */
/* 194 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 195 */
/* 196 */       if (mapelements_isNull) {
/* 197 */         serializefromobject_rowWriter.setNullAt(0);
/* 198 */       } else {
/* 199 */         // Remember the current cursor so that we can calculate how many bytes are
/* 200 */         // written later.
/* 201 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 202 */
/* 203 */         if (serializefromobject_value instanceof UnsafeMapData) {
/* 204 */           final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 205 */           // grow the global buffer before writing data.
/* 206 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 207 */           ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 208 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 209 */
/* 210 */         } else {
/* 211 */           final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 212 */           final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 213 */
/* 214 */           // preserve 8 bytes to write the key array numBytes later.
/* 215 */           serializefromobject_holder.grow(8);
/* 216 */           serializefromobject_holder.cursor += 8;
/* 217 */
/* 218 */           // Remember the current cursor so that we can write numBytes of key array later.
/* 219 */           final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 220 */
/* 221 */           if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 222 */             final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 223 */             // grow the global buffer before writing data.
/* 224 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 225 */             ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 226 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 227 */
/* 228 */           } else {
/* 229 */             final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 230 */             serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 231 */
/* 232 */             for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 233 */               if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 234 */                 serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 235 */               } else {
/* 236 */                 final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 237 */                 serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 238 */               }
/* 239 */             }
/* 240 */           }
/* 241 */
/* 242 */           // Write the numBytes of key array into the first 8 bytes.
/* 243 */           Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 244 */
/* 245 */           if (serializefromobject_values instanceof UnsafeArrayData) {
/* 246 */             final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 247 */             // grow the global buffer before writing data.
/* 248 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 249 */             ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 250 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 251 */
/* 252 */           } else {
/* 253 */             final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 254 */             serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 255 */
/* 256 */             for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 257 */               if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 258 */                 serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 259 */               } else {
/* 260 */                 final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 261 */                 serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 262 */               }
/* 263 */             }
/* 264 */           }
/* 265 */
/* 266 */         }
/* 267 */
/* 268 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 269 */       }
/* 270 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 271 */       append(serializefromobject_result);
/* 272 */       if (shouldStop()) return;
/* 273 */     }
/* 274 */   }
/* 275 */ }
```

Codegen for `java.util.Map`:

```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */   private int CollectObjectsToMap_loopValue0;
/* 011 */   private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */   private int CollectObjectsToMap_loopValue2;
/* 013 */   private UnsafeRow deserializetoobject_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */   private java.util.HashMap mapelements_argValue;
/* 017 */   private UnsafeRow mapelements_result;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */   private UnsafeRow serializefromobject_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */   public GeneratedIterator(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */     wholestagecodegen_init_0();
/* 034 */     wholestagecodegen_init_1();
/* 035 */
/* 036 */   }
/* 037 */
/* 038 */   private void wholestagecodegen_init_0() {
/* 039 */     inputadapter_input = inputs[0];
/* 040 */
/* 041 */     deserializetoobject_result = new UnsafeRow(1);
/* 042 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */     mapelements_result = new UnsafeRow(1);
/* 046 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */     serializefromobject_result = new UnsafeRow(1);
/* 049 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */   }
/* 054 */
/* 055 */   private void wholestagecodegen_init_1() {
/* 056 */     this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */   }
/* 059 */
/* 060 */   protected void processNext() throws java.io.IOException {
/* 061 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */       MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */       boolean deserializetoobject_isNull1 = true;
/* 067 */       ArrayData deserializetoobject_value1 = null;
/* 068 */       if (!inputadapter_isNull) {
/* 069 */         deserializetoobject_isNull1 = false;
/* 070 */         if (!deserializetoobject_isNull1) {
/* 071 */           Object deserializetoobject_funcResult = null;
/* 072 */           deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */           if (deserializetoobject_funcResult == null) {
/* 074 */             deserializetoobject_isNull1 = true;
/* 075 */           } else {
/* 076 */             deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */           }
/* 078 */
/* 079 */         }
/* 080 */         deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */       }
/* 082 */
/* 083 */       boolean deserializetoobject_isNull3 = true;
/* 084 */       ArrayData deserializetoobject_value3 = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull3 = false;
/* 087 */         if (!deserializetoobject_isNull3) {
/* 088 */           Object deserializetoobject_funcResult1 = null;
/* 089 */           deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */           if (deserializetoobject_funcResult1 == null) {
/* 091 */             deserializetoobject_isNull3 = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */       }
/* 099 */       java.util.HashMap deserializetoobject_value = null;
/* 100 */
/* 101 */       if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */         (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */         throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */       }
/* 105 */
/* 106 */       if (!deserializetoobject_isNull1) {
/* 107 */         if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */           throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */         }
/* 110 */         int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */         java.util.Map CollectObjectsToMap_builderValue5 = new java.util.HashMap(deserializetoobject_dataLength);
/* 112 */
/* 113 */         int deserializetoobject_loopIndex = 0;
/* 114 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 115 */           CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 116 */           CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 117 */           CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 118 */           CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 119 */
/* 120 */           if (CollectObjectsToMap_loopIsNull1) {
/* 121 */             throw new RuntimeException("Found null in map key!");
/* 122 */           }
/* 123 */
/* 124 */           CollectObjectsToMap_builderValue5.put(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 125 */
/* 126 */           deserializetoobject_loopIndex += 1;
/* 127 */         }
/* 128 */
/* 129 */         deserializetoobject_value = (java.util.HashMap) CollectObjectsToMap_builderValue5;
/* 130 */       }
/* 131 */
/* 132 */       boolean mapelements_isNull = true;
/* 133 */       java.util.HashMap mapelements_value = null;
/* 134 */       if (!false) {
/* 135 */         mapelements_argValue = deserializetoobject_value;
/* 136 */
/* 137 */         mapelements_isNull = false;
/* 138 */         if (!mapelements_isNull) {
/* 139 */           Object mapelements_funcResult = null;
/* 140 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 141 */           if (mapelements_funcResult == null) {
/* 142 */             mapelements_isNull = true;
/* 143 */           } else {
/* 144 */             mapelements_value = (java.util.HashMap) mapelements_funcResult;
/* 145 */           }
/* 146 */
/* 147 */         }
/* 148 */         mapelements_isNull = mapelements_value == null;
/* 149 */       }
/* 150 */
/* 151 */       MapData serializefromobject_value = null;
/* 152 */       if (!mapelements_isNull) {
/* 153 */         final int serializefromobject_length = mapelements_value.size();
/* 154 */         final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 155 */         final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 156 */         int serializefromobject_index = 0;
/* 157 */         final java.util.Iterator serializefromobject_entries = mapelements_value.entrySet().iterator();
/* 158 */         while(serializefromobject_entries.hasNext()) {
/* 159 */           final java.util.Map$Entry serializefromobject_entry = (java.util.Map$Entry) serializefromobject_entries.next();
/* 160 */           int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry.getKey();
/* 161 */           int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry.getValue();
/* 162 */
/* 163 */           boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 164 */
/* 165 */           if (false) {
/* 166 */             throw new RuntimeException("Cannot use null as map key!");
/* 167 */           } else {
/* 168 */             serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 169 */           }
/* 170 */
/* 171 */           if (false) {
/* 172 */             serializefromobject_convertedValues[serializefromobject_index] = null;
/* 173 */           } else {
/* 174 */             serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 175 */           }
/* 176 */
/* 177 */           serializefromobject_index++;
/* 178 */         }
/* 179 */
/* 180 */         serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 181 */       }
/* 182 */       serializefromobject_holder.reset();
/* 183 */
/* 184 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 185 */
/* 186 */       if (mapelements_isNull) {
/* 187 */         serializefromobject_rowWriter.setNullAt(0);
/* 188 */       } else {
/* 189 */         // Remember the current cursor so that we can calculate how many bytes are
/* 190 */         // written later.
/* 191 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 192 */
/* 193 */         if (serializefromobject_value instanceof UnsafeMapData) {
/* 194 */           final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 195 */           // grow the global buffer before writing data.
/* 196 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 197 */           ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 198 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 199 */
/* 200 */         } else {
/* 201 */           final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 202 */           final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 203 */
/* 204 */           // preserve 8 bytes to write the key array numBytes later.
/* 205 */           serializefromobject_holder.grow(8);
/* 206 */           serializefromobject_holder.cursor += 8;
/* 207 */
/* 208 */           // Remember the current cursor so that we can write numBytes of key array later.
/* 209 */           final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 210 */
/* 211 */           if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 212 */             final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 213 */             // grow the global buffer before writing data.
/* 214 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 215 */             ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 216 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 217 */
/* 218 */           } else {
/* 219 */             final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 220 */             serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 221 */
/* 222 */             for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 223 */               if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 224 */                 serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 225 */               } else {
/* 226 */                 final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 227 */                 serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 228 */               }
/* 229 */             }
/* 230 */           }
/* 231 */
/* 232 */           // Write the numBytes of key array into the first 8 bytes.
/* 233 */           Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 234 */
/* 235 */           if (serializefromobject_values instanceof UnsafeArrayData) {
/* 236 */             final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 237 */             // grow the global buffer before writing data.
/* 238 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 239 */             ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 240 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 241 */
/* 242 */           } else {
/* 243 */             final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 244 */             serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 245 */
/* 246 */             for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 247 */               if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 248 */                 serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 249 */               } else {
/* 250 */                 final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 251 */                 serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 252 */               }
/* 253 */             }
/* 254 */           }
/* 255 */
/* 256 */         }
/* 257 */
/* 258 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 259 */       }
/* 260 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 261 */       append(serializefromobject_result);
/* 262 */       if (shouldStop()) return;
/* 263 */     }
/* 264 */   }
/* 265 */ }
```

## How was this patch tested?

```
build/mvn -DskipTests clean package && dev/run-tests
```

Additionally in Spark shell:

```
scala> Seq(collection.mutable.HashMap(1 -> 2, 2 -> 3)).toDS().map(_ += (3 -> 4)).collect()
res0: Array[scala.collection.mutable.HashMap[Int,Int]] = Array(Map(2 -> 3, 1 -> 2, 3 -> 4))
```

Author: Michal Senkyr <mike.senkyr@gmail.com>
Author: Michal Šenkýř <mike.senkyr@gmail.com>

Closes #16986 from michalsenkyr/dataset-map-builder.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0538f3b0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0538f3b0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0538f3b0

Branch: refs/heads/master
Commit: 0538f3b0ae4b80750ab81b210ad6fe77178337bf
Parents: a7c61c1
Author: Michal Senkyr <mike.senkyr@gmail.com>
Authored: Mon Jun 12 08:47:01 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Mon Jun 12 08:47:01 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |  33 +---
 .../catalyst/expressions/objects/objects.scala  | 169 ++++++++++++++++++-
 .../sql/catalyst/ScalaReflectionSuite.scala     |  25 +++
 .../org/apache/spark/sql/SQLImplicits.scala     |   5 +
 .../spark/sql/DatasetPrimitiveSuite.scala       |  86 ++++++++++
 5 files changed, 291 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8713053..d580cf4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -335,31 +335,12 @@ object ScalaReflection extends ScalaReflection {
         // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t
 
-        val keyData =
-          Invoke(
-            MapObjects(
-              p => deserializerFor(keyType, Some(p), walkedTypePath),
-              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
-                returnNullable = false),
-              schemaFor(keyType).dataType),
-            "array",
-            ObjectType(classOf[Array[Any]]), returnNullable = false)
-
-        val valueData =
-          Invoke(
-            MapObjects(
-              p => deserializerFor(valueType, Some(p), walkedTypePath),
-              Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType),
-                returnNullable = false),
-              schemaFor(valueType).dataType),
-            "array",
-            ObjectType(classOf[Array[Any]]), returnNullable = false)
-
-        StaticInvoke(
-          ArrayBasedMapData.getClass,
-          ObjectType(classOf[scala.collection.immutable.Map[_, _]]),
-          "toScalaMap",
-          keyData :: valueData :: Nil)
+        CollectObjectsToMap(
+          p => deserializerFor(keyType, Some(p), walkedTypePath),
+          p => deserializerFor(valueType, Some(p), walkedTypePath),
+          getPath,
+          mirror.runtimeClass(t.typeSymbol.asClass)
+        )
 
       case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
         val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()

http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 1a202ec..79b7b9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -652,6 +652,173 @@ case class MapObjects private(
   }
 }
 
+object CollectObjectsToMap {
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+  /**
+   * Construct an instance of CollectObjectsToMap case class.
+   *
+   * @param keyFunction The function applied on the key collection elements.
+   * @param valueFunction The function applied on the value collection elements.
+   * @param inputData An expression that when evaluated returns a map object.
+   * @param collClass The type of the resulting collection.
+   */
+  def apply(
+      keyFunction: Expression => Expression,
+      valueFunction: Expression => Expression,
+      inputData: Expression,
+      collClass: Class[_]): CollectObjectsToMap = {
+    val id = curId.getAndIncrement()
+    val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id"
+    val mapType = inputData.dataType.asInstanceOf[MapType]
+    val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
+    val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id"
+    val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id"
+    val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
+    CollectObjectsToMap(
+      keyLoopValue, keyFunction(keyLoopVar),
+      valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar),
+      inputData, collClass)
+  }
+}
+
+/**
+ * Expression used to convert a Catalyst Map to an external Scala Map.
+ * The collection is constructed using the associated builder, obtained by calling `newBuilder`
+ * on the collection's companion object.
+ *
+ * @param keyLoopValue the name of the loop variable that is used when iterating over the key
+ *                     collection, and which is used as input for the `keyLambdaFunction`
+ * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as
+ *                          a lambda function to handle collection elements.
+ * @param valueLoopValue the name of the loop variable that is used when iterating over the value
+ *                       collection, and which is used as input for the `valueLambdaFunction`
+ * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over
+ *                        the value collection, and which is used as input for the
+ *                        `valueLambdaFunction`
+ * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as
+ *                            a lambda function to handle collection elements.
+ * @param inputData An expression that when evaluated returns a map object.
+ * @param collClass The type of the resulting collection.
+ */
+case class CollectObjectsToMap private(
+    keyLoopValue: String,
+    keyLambdaFunction: Expression,
+    valueLoopValue: String,
+    valueLoopIsNull: String,
+    valueLambdaFunction: Expression,
+    inputData: Expression,
+    collClass: Class[_]) extends Expression with NonSQLExpression {
+
+  override def nullable: Boolean = inputData.nullable
+
+  override def children: Seq[Expression] =
+    keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def dataType: DataType = ObjectType(collClass)
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
+    // When we want to apply MapObjects on it, we have to use it.
+    def inputDataType(dataType: DataType) = dataType match {
+      case p: PythonUserDefinedType => p.sqlType
+      case _ => dataType
+    }
+
+    val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
+    val keyElementJavaType = ctx.javaType(mapType.keyType)
+    ctx.addMutableState(keyElementJavaType, keyLoopValue, "")
+    val genKeyFunction = keyLambdaFunction.genCode(ctx)
+    val valueElementJavaType = ctx.javaType(mapType.valueType)
+    ctx.addMutableState("boolean", valueLoopIsNull, "")
+    ctx.addMutableState(valueElementJavaType, valueLoopValue, "")
+    val genValueFunction = valueLambdaFunction.genCode(ctx)
+    val genInputData = inputData.genCode(ctx)
+    val dataLength = ctx.freshName("dataLength")
+    val loopIndex = ctx.freshName("loopIndex")
+    val tupleLoopValue = ctx.freshName("tupleLoopValue")
+    val builderValue = ctx.freshName("builderValue")
+
+    val getLength = s"${genInputData.value}.numElements()"
+
+    val keyArray = ctx.freshName("keyArray")
+    val valueArray = ctx.freshName("valueArray")
+    val getKeyArray =
+      s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
+    val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
+    val getValueArray =
+      s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
+    val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex)
+
+    // Make a copy of the data if it's unsafe-backed
+    def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
+      s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value"
+    def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) =
+      lambdaFunction.dataType match {
+        case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
+        case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
+        case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
+        case _ => genFunction.value
+      }
+    val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction)
+    val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)
+
+    val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
+
+    val builderClass = classOf[Builder[_, _]].getName
+    val constructBuilder = s"""
+      $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder();
+      $builderValue.sizeHint($dataLength);
+    """
+
+    val tupleClass = classOf[(_, _)].getName
+    val appendToBuilder = s"""
+      $tupleClass $tupleLoopValue;
+
+      if (${genValueFunction.isNull}) {
+        $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null);
+      } else {
+        $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue);
+      }
+
+      $builderValue.$$plus$$eq($tupleLoopValue);
+     """
+    val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"
+
+    val code = s"""
+      ${genInputData.code}
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+
+      if (!${genInputData.isNull}) {
+        int $dataLength = $getLength;
+        $constructBuilder
+        $getKeyArray
+        $getValueArray
+
+        int $loopIndex = 0;
+        while ($loopIndex < $dataLength) {
+          $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar);
+          $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar);
+          $valueLoopNullCheck
+
+          ${genKeyFunction.code}
+          ${genValueFunction.code}
+
+          $appendToBuilder
+
+          $loopIndex += 1;
+        }
+
+        $getBuilderResult
+      }
+    """
+    ev.copy(code = code, isNull = genInputData.isNull)
+  }
+}
+
 object ExternalMapToCatalyst {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 70ad064..ff2414b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -314,6 +314,31 @@ class ScalaReflectionSuite extends SparkFunSuite {
     assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
   }
 
+  test("serialize and deserialize arbitrary map types") {
+    val mapSerializer = serializerFor[Map[Int, Int]](BoundReference(
+      0, ObjectType(classOf[Map[Int, Int]]), nullable = false))
+    assert(mapSerializer.dataType.head.dataType ==
+      MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val mapDeserializer = deserializerFor[Map[Int, Int]]
+    assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]]))
+
+    import scala.collection.immutable.HashMap
+    val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference(
+      0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false))
+    assert(hashMapSerializer.dataType.head.dataType ==
+      MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]
+    assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]]))
+
+    import scala.collection.mutable.{LinkedHashMap => LHMap}
+    val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference(
+      0, ObjectType(classOf[LHMap[Long, String]]), nullable = false))
+    assert(linkedHashMapSerializer.dataType.head.dataType ==
+      MapType(LongType, StringType, valueContainsNull = true))
+    val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]
+    assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
+  }
+
   private val dataTypeForComplexData = dataTypeFor[ComplexData]
   private val typeOfComplexData = typeOf[ComplexData]
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 17671ea..86574e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import scala.collection.Map
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 
@@ -166,6 +167,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
   /** @since 2.2.0 */
   implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
 
+  // Maps
+  /** @since 2.3.0 */
+  implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder()
+
   // Arrays
 
   /** @since 1.6.1 */

http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 7e2949a..4126660 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import scala.collection.immutable.Queue
+import scala.collection.mutable.{LinkedHashMap => LHMap}
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.test.SharedSQLContext
@@ -30,8 +31,14 @@ case class ListClass(l: List[Int])
 
 case class QueueClass(q: Queue[Int])
 
+case class MapClass(m: Map[Int, Int])
+
+case class LHMapClass(m: LHMap[Int, Int])
+
 case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)
 
+case class ComplexMapClass(map: MapClass, lhmap: LHMapClass)
+
 package object packageobject {
   case class PackageClass(value: Int)
 }
@@ -258,11 +265,90 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
   }
 
+  test("arbitrary maps") {
+    checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2))
+    checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong))
+    checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble))
+    checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat))
+    checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte))
+    checkDataset(Seq(Map(1.toShort -> 2.toShort)).toDS(), Map(1.toShort -> 2.toShort))
+    checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false))
+    checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2"))
+    checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2)))
+    checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2)))
+    checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong))
+
+    checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2))
+    checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 2.toLong))
+    checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble))
+    checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat))
+    checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte))
+    checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort))
+    checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false))
+    checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2"))
+    checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2)))
+    checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2)))
+    checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
+  }
+
+  ignore("SPARK-19104: map and product combinations") {
+    // Case classes
+    checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2)))
+    checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3))))
+    checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3))
+    checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(),
+      Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4))))
+    checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3))))
+    checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3))
+    checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(),
+      LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4))))
+
+    checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2)))
+    checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(),
+      Map(1 -> LHMapClass(LHMap(2 -> 3))))
+    checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(),
+      Map(LHMapClass(LHMap(1 -> 2)) -> 3))
+    checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(),
+      Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4))))
+    checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(),
+      LHMap(1 -> LHMapClass(LHMap(2 -> 3))))
+    checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(),
+      LHMap(LHMapClass(LHMap(1 -> 2)) -> 3))
+    checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(),
+      LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4))))
+
+    val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4)))
+    checkDataset(Seq(complex).toDS(), complex)
+    checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex))
+    checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5))
+    checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex))
+    checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex))
+    checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5))
+    checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex))
+
+    // Tuples
+    checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 4)).toDS(), Map(1 -> 2) -> Map(3 -> 4))
+    checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4))
+    checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4))
+    checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4))
+    checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(),
+      LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2"))))
+
+    // Complex
+    checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))).toDS(),
+      LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4))))
+  }
+
   test("nested sequences") {
     checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1)))
     checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1)))
   }
 
+  test("nested maps") {
+    checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3)))
+    checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3))
+  }
+
   test("package objects") {
     import packageobject._
     checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message