From: wenchen@apache.org
To: commits@spark.apache.org
Subject: spark git commit: [SPARK-18891][SQL] Support for Scala Map collection types
Date: Mon, 12 Jun 2017 00:47:10 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/master a7c61c100 -> 0538f3b0a


[SPARK-18891][SQL] Support for Scala Map collection types

## What changes were proposed in this pull request?

Add support for arbitrary Scala `Map` types in deserialization, as well as a generic implicit encoder. The builder approach from #16541 is used to construct any provided `Map` type upon deserialization. Note that this PR also adds (ignored) tests for issue [SPARK-19104 CompileException with Map and Case Class in Spark 2.1.0](https://issues.apache.org/jira/browse/SPARK-19104) but does not solve it.
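For illustration, the following minimal plain-Scala sketch (illustration only, not part of this patch; the `buildMap` helper and its signature are hypothetical) shows the builder idea: the target collection's companion object supplies a `Builder`, which is size-hinted, fed key/value pairs, and finalized with `result()`, mirroring what the generated deserializer below does through `Map$.MODULE$.newBuilder()`:

```
import scala.collection.mutable.Builder
import scala.collection.immutable.HashMap

// Hypothetical helper: rebuild an arbitrary Map collection from parallel
// key/value arrays using a Builder obtained from the target type's companion,
// following the same sizeHint/append/result() sequence as the generated code.
def buildMap[K, V, M <: scala.collection.Map[K, V]](
    keys: Array[K],
    values: Array[V],
    newBuilder: () => Builder[(K, V), M]): M = {
  require(keys.length == values.length, "Inconsistent lengths of key-value arrays")
  val builder = newBuilder()
  builder.sizeHint(keys.length)
  var i = 0
  while (i < keys.length) {
    builder += (keys(i) -> values(i))
    i += 1
  }
  builder.result()
}

// Example: materialize a scala.collection.immutable.HashMap from decoded columns.
val m: HashMap[Int, Int] =
  buildMap(Array(1, 2, 3), Array(10, 20, 30), () => HashMap.newBuilder[Int, Int])
```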
Added support for Java Maps in codegen code (encoders will be added in a different PR) with the following default implementations for interfaces/abstract classes:

* `java.util.Map`, `java.util.AbstractMap` => `java.util.HashMap`
* `java.util.SortedMap`, `java.util.NavigableMap` => `java.util.TreeMap`
* `java.util.concurrent.ConcurrentMap` => `java.util.concurrent.ConcurrentHashMap`
* `java.util.concurrent.ConcurrentNavigableMap` => `java.util.concurrent.ConcurrentSkipListMap`

Resulting codegen for `Seq(Map(1 -> 2)).toDS().map(identity).queryExecution.debug.codegen`:

```
/* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private boolean CollectObjectsToMap_loopIsNull1; /* 010 */ private int CollectObjectsToMap_loopValue0; /* 011 */ private boolean CollectObjectsToMap_loopIsNull3; /* 012 */ private int CollectObjectsToMap_loopValue2; /* 013 */ private UnsafeRow deserializetoobject_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 016 */ private scala.collection.immutable.Map mapelements_argValue; /* 017 */ private UnsafeRow mapelements_result; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 019 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 020 */ private UnsafeRow serializefromobject_result; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 023 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter; /* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1; /* 025 */ /* 026 */ public GeneratedIterator(Object[] references) { /* 027 */ this.references = references; /* 028 */ } /* 029 */ /* 030 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 031 */ partitionIndex = index; /* 032 */ this.inputs = inputs; /* 033 */ wholestagecodegen_init_0(); /* 034 */ wholestagecodegen_init_1(); /* 035 */ /* 036 */ } /* 037 */ /* 038 */ private void wholestagecodegen_init_0() { /* 039 */ inputadapter_input = inputs[0]; /* 040 */ /* 041 */ deserializetoobject_result = new UnsafeRow(1); /* 042 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32); /* 043 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 044 */ /* 045 */ mapelements_result = new UnsafeRow(1); /* 046 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32); /* 047 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 048 */ serializefromobject_result = new UnsafeRow(1); /* 049 */
this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32); /* 050 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 051 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 052 */ /* 053 */ } /* 054 */ /* 055 */ private void wholestagecodegen_init_1() { /* 056 */ this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 057 */ /* 058 */ } /* 059 */ /* 060 */ protected void processNext() throws java.io.IOException { /* 061 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 062 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 063 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 064 */ MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0)); /* 065 */ /* 066 */ boolean deserializetoobject_isNull1 = true; /* 067 */ ArrayData deserializetoobject_value1 = null; /* 068 */ if (!inputadapter_isNull) { /* 069 */ deserializetoobject_isNull1 = false; /* 070 */ if (!deserializetoobject_isNull1) { /* 071 */ Object deserializetoobject_funcResult = null; /* 072 */ deserializetoobject_funcResult = inputadapter_value.keyArray(); /* 073 */ if (deserializetoobject_funcResult == null) { /* 074 */ deserializetoobject_isNull1 = true; /* 075 */ } else { /* 076 */ deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult; /* 077 */ } /* 078 */ /* 079 */ } /* 080 */ deserializetoobject_isNull1 = deserializetoobject_value1 == null; /* 081 */ } /* 082 */ /* 083 */ boolean deserializetoobject_isNull3 = true; /* 084 */ ArrayData deserializetoobject_value3 = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull3 = false; /* 087 */ if (!deserializetoobject_isNull3) { /* 088 */ Object deserializetoobject_funcResult1 = null; /* 089 */ deserializetoobject_funcResult1 = inputadapter_value.valueArray(); /* 090 */ if (deserializetoobject_funcResult1 == null) { /* 091 */ deserializetoobject_isNull3 = true; /* 092 */ } else { /* 093 */ deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1; /* 094 */ } /* 095 */ /* 096 */ } /* 097 */ deserializetoobject_isNull3 = deserializetoobject_value3 == null; /* 098 */ } /* 099 */ scala.collection.immutable.Map deserializetoobject_value = null; /* 100 */ /* 101 */ if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) || /* 102 */ (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) { /* 103 */ throw new RuntimeException("Invalid state: Inconsistent nullability of key-value"); /* 104 */ } /* 105 */ /* 106 */ if (!deserializetoobject_isNull1) { /* 107 */ if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) { /* 108 */ throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays"); /* 109 */ } /* 110 */ int deserializetoobject_dataLength = deserializetoobject_value1.numElements(); /* 111 */ /* 112 */ scala.collection.mutable.Builder CollectObjectsToMap_builderValue5 = scala.collection.immutable.Map$.MODULE$.newBuilder(); /* 113 */ CollectObjectsToMap_builderValue5.sizeHint(deserializetoobject_dataLength); /* 114 */ /* 115 */ int deserializetoobject_loopIndex = 0; /* 116 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 117 */ 
CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex)); /* 118 */ CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex)); /* 119 */ CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex); /* 120 */ CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex); /* 121 */ /* 122 */ if (CollectObjectsToMap_loopIsNull1) { /* 123 */ throw new RuntimeException("Found null in map key!"); /* 124 */ } /* 125 */ /* 126 */ scala.Tuple2 CollectObjectsToMap_loopValue4; /* 127 */ /* 128 */ if (CollectObjectsToMap_loopIsNull3) { /* 129 */ CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, null); /* 130 */ } else { /* 131 */ CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2); /* 132 */ } /* 133 */ /* 134 */ CollectObjectsToMap_builderValue5.$plus$eq(CollectObjectsToMap_loopValue4); /* 135 */ /* 136 */ deserializetoobject_loopIndex += 1; /* 137 */ } /* 138 */ /* 139 */ deserializetoobject_value = (scala.collection.immutable.Map) CollectObjectsToMap_builderValue5.result(); /* 140 */ } /* 141 */ /* 142 */ boolean mapelements_isNull = true; /* 143 */ scala.collection.immutable.Map mapelements_value = null; /* 144 */ if (!false) { /* 145 */ mapelements_argValue = deserializetoobject_value; /* 146 */ /* 147 */ mapelements_isNull = false; /* 148 */ if (!mapelements_isNull) { /* 149 */ Object mapelements_funcResult = null; /* 150 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue); /* 151 */ if (mapelements_funcResult == null) { /* 152 */ mapelements_isNull = true; /* 153 */ } else { /* 154 */ mapelements_value = (scala.collection.immutable.Map) mapelements_funcResult; /* 155 */ } /* 156 */ /* 157 */ } /* 158 */ mapelements_isNull = mapelements_value == null; /* 159 */ } /* 160 */ /* 161 */ MapData serializefromobject_value = null; /* 162 */ if (!mapelements_isNull) { /* 163 */ final int serializefromobject_length = mapelements_value.size(); /* 164 */ final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length]; /* 165 */ final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length]; /* 166 */ int serializefromobject_index = 0; /* 167 */ final scala.collection.Iterator serializefromobject_entries = mapelements_value.iterator(); /* 168 */ while(serializefromobject_entries.hasNext()) { /* 169 */ final scala.Tuple2 serializefromobject_entry = (scala.Tuple2) serializefromobject_entries.next(); /* 170 */ int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry._1(); /* 171 */ int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry._2(); /* 172 */ /* 173 */ boolean ExternalMapToCatalyst_value_isNull1 = false; /* 174 */ /* 175 */ if (false) { /* 176 */ throw new RuntimeException("Cannot use null as map key!"); /* 177 */ } else { /* 178 */ serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1; /* 179 */ } /* 180 */ /* 181 */ if (false) { /* 182 */ serializefromobject_convertedValues[serializefromobject_index] = null; /* 183 */ } else { /* 184 */ serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1; /* 185 */ } /* 186 */ /* 187 */ serializefromobject_index++; /* 188 */ } /* 189 */ /* 190 */ serializefromobject_value = new 
org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues)); /* 191 */ } /* 192 */ serializefromobject_holder.reset(); /* 193 */ /* 194 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 195 */ /* 196 */ if (mapelements_isNull) { /* 197 */ serializefromobject_rowWriter.setNullAt(0); /* 198 */ } else { /* 199 */ // Remember the current cursor so that we can calculate how many bytes are /* 200 */ // written later. /* 201 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 202 */ /* 203 */ if (serializefromobject_value instanceof UnsafeMapData) { /* 204 */ final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes(); /* 205 */ // grow the global buffer before writing data. /* 206 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 207 */ ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 208 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 209 */ /* 210 */ } else { /* 211 */ final ArrayData serializefromobject_keys = serializefromobject_value.keyArray(); /* 212 */ final ArrayData serializefromobject_values = serializefromobject_value.valueArray(); /* 213 */ /* 214 */ // preserve 8 bytes to write the key array numBytes later. /* 215 */ serializefromobject_holder.grow(8); /* 216 */ serializefromobject_holder.cursor += 8; /* 217 */ /* 218 */ // Remember the current cursor so that we can write numBytes of key array later. /* 219 */ final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor; /* 220 */ /* 221 */ if (serializefromobject_keys instanceof UnsafeArrayData) { /* 222 */ final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes(); /* 223 */ // grow the global buffer before writing data. /* 224 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes1); /* 225 */ ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 226 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes1; /* 227 */ /* 228 */ } else { /* 229 */ final int serializefromobject_numElements = serializefromobject_keys.numElements(); /* 230 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4); /* 231 */ /* 232 */ for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) { /* 233 */ if (serializefromobject_keys.isNullAt(serializefromobject_index1)) { /* 234 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index1); /* 235 */ } else { /* 236 */ final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1); /* 237 */ serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element); /* 238 */ } /* 239 */ } /* 240 */ } /* 241 */ /* 242 */ // Write the numBytes of key array into the first 8 bytes. 
/* 243 */ Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1); /* 244 */ /* 245 */ if (serializefromobject_values instanceof UnsafeArrayData) { /* 246 */ final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes(); /* 247 */ // grow the global buffer before writing data. /* 248 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes2); /* 249 */ ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 250 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes2; /* 251 */ /* 252 */ } else { /* 253 */ final int serializefromobject_numElements1 = serializefromobject_values.numElements(); /* 254 */ serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4); /* 255 */ /* 256 */ for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) { /* 257 */ if (serializefromobject_values.isNullAt(serializefromobject_index2)) { /* 258 */ serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2); /* 259 */ } else { /* 260 */ final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2); /* 261 */ serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1); /* 262 */ } /* 263 */ } /* 264 */ } /* 265 */ /* 266 */ } /* 267 */ /* 268 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 269 */ } /* 270 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 271 */ append(serializefromobject_result); /* 272 */ if (shouldStop()) return; /* 273 */ } /* 274 */ } /* 275 */ } ``` Codegen for `java.util.Map`: ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private boolean CollectObjectsToMap_loopIsNull1; /* 010 */ private int CollectObjectsToMap_loopValue0; /* 011 */ private boolean CollectObjectsToMap_loopIsNull3; /* 012 */ private int CollectObjectsToMap_loopValue2; /* 013 */ private UnsafeRow deserializetoobject_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 016 */ private java.util.HashMap mapelements_argValue; /* 017 */ private UnsafeRow mapelements_result; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 019 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 020 */ private UnsafeRow serializefromobject_result; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 023 */ private 
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter; /* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1; /* 025 */ /* 026 */ public GeneratedIterator(Object[] references) { /* 027 */ this.references = references; /* 028 */ } /* 029 */ /* 030 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 031 */ partitionIndex = index; /* 032 */ this.inputs = inputs; /* 033 */ wholestagecodegen_init_0(); /* 034 */ wholestagecodegen_init_1(); /* 035 */ /* 036 */ } /* 037 */ /* 038 */ private void wholestagecodegen_init_0() { /* 039 */ inputadapter_input = inputs[0]; /* 040 */ /* 041 */ deserializetoobject_result = new UnsafeRow(1); /* 042 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32); /* 043 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 044 */ /* 045 */ mapelements_result = new UnsafeRow(1); /* 046 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32); /* 047 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 048 */ serializefromobject_result = new UnsafeRow(1); /* 049 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32); /* 050 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 051 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 052 */ /* 053 */ } /* 054 */ /* 055 */ private void wholestagecodegen_init_1() { /* 056 */ this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 057 */ /* 058 */ } /* 059 */ /* 060 */ protected void processNext() throws java.io.IOException { /* 061 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 062 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 063 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 064 */ MapData inputadapter_value = inputadapter_isNull ? 
null : (inputadapter_row.getMap(0)); /* 065 */ /* 066 */ boolean deserializetoobject_isNull1 = true; /* 067 */ ArrayData deserializetoobject_value1 = null; /* 068 */ if (!inputadapter_isNull) { /* 069 */ deserializetoobject_isNull1 = false; /* 070 */ if (!deserializetoobject_isNull1) { /* 071 */ Object deserializetoobject_funcResult = null; /* 072 */ deserializetoobject_funcResult = inputadapter_value.keyArray(); /* 073 */ if (deserializetoobject_funcResult == null) { /* 074 */ deserializetoobject_isNull1 = true; /* 075 */ } else { /* 076 */ deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult; /* 077 */ } /* 078 */ /* 079 */ } /* 080 */ deserializetoobject_isNull1 = deserializetoobject_value1 == null; /* 081 */ } /* 082 */ /* 083 */ boolean deserializetoobject_isNull3 = true; /* 084 */ ArrayData deserializetoobject_value3 = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull3 = false; /* 087 */ if (!deserializetoobject_isNull3) { /* 088 */ Object deserializetoobject_funcResult1 = null; /* 089 */ deserializetoobject_funcResult1 = inputadapter_value.valueArray(); /* 090 */ if (deserializetoobject_funcResult1 == null) { /* 091 */ deserializetoobject_isNull3 = true; /* 092 */ } else { /* 093 */ deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1; /* 094 */ } /* 095 */ /* 096 */ } /* 097 */ deserializetoobject_isNull3 = deserializetoobject_value3 == null; /* 098 */ } /* 099 */ java.util.HashMap deserializetoobject_value = null; /* 100 */ /* 101 */ if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) || /* 102 */ (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) { /* 103 */ throw new RuntimeException("Invalid state: Inconsistent nullability of key-value"); /* 104 */ } /* 105 */ /* 106 */ if (!deserializetoobject_isNull1) { /* 107 */ if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) { /* 108 */ throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays"); /* 109 */ } /* 110 */ int deserializetoobject_dataLength = deserializetoobject_value1.numElements(); /* 111 */ java.util.Map CollectObjectsToMap_builderValue5 = new java.util.HashMap(deserializetoobject_dataLength); /* 112 */ /* 113 */ int deserializetoobject_loopIndex = 0; /* 114 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 115 */ CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex)); /* 116 */ CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex)); /* 117 */ CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex); /* 118 */ CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex); /* 119 */ /* 120 */ if (CollectObjectsToMap_loopIsNull1) { /* 121 */ throw new RuntimeException("Found null in map key!"); /* 122 */ } /* 123 */ /* 124 */ CollectObjectsToMap_builderValue5.put(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2); /* 125 */ /* 126 */ deserializetoobject_loopIndex += 1; /* 127 */ } /* 128 */ /* 129 */ deserializetoobject_value = (java.util.HashMap) CollectObjectsToMap_builderValue5; /* 130 */ } /* 131 */ /* 132 */ boolean mapelements_isNull = true; /* 133 */ java.util.HashMap mapelements_value = null; /* 134 */ if (!false) { /* 135 */ mapelements_argValue = deserializetoobject_value; /* 136 */ /* 137 */ mapelements_isNull = false; /* 138 */ if 
(!mapelements_isNull) { /* 139 */ Object mapelements_funcResult = null; /* 140 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue); /* 141 */ if (mapelements_funcResult == null) { /* 142 */ mapelements_isNull = true; /* 143 */ } else { /* 144 */ mapelements_value = (java.util.HashMap) mapelements_funcResult; /* 145 */ } /* 146 */ /* 147 */ } /* 148 */ mapelements_isNull = mapelements_value == null; /* 149 */ } /* 150 */ /* 151 */ MapData serializefromobject_value = null; /* 152 */ if (!mapelements_isNull) { /* 153 */ final int serializefromobject_length = mapelements_value.size(); /* 154 */ final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length]; /* 155 */ final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length]; /* 156 */ int serializefromobject_index = 0; /* 157 */ final java.util.Iterator serializefromobject_entries = mapelements_value.entrySet().iterator(); /* 158 */ while(serializefromobject_entries.hasNext()) { /* 159 */ final java.util.Map$Entry serializefromobject_entry = (java.util.Map$Entry) serializefromobject_entries.next(); /* 160 */ int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry.getKey(); /* 161 */ int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry.getValue(); /* 162 */ /* 163 */ boolean ExternalMapToCatalyst_value_isNull1 = false; /* 164 */ /* 165 */ if (false) { /* 166 */ throw new RuntimeException("Cannot use null as map key!"); /* 167 */ } else { /* 168 */ serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1; /* 169 */ } /* 170 */ /* 171 */ if (false) { /* 172 */ serializefromobject_convertedValues[serializefromobject_index] = null; /* 173 */ } else { /* 174 */ serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1; /* 175 */ } /* 176 */ /* 177 */ serializefromobject_index++; /* 178 */ } /* 179 */ /* 180 */ serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues)); /* 181 */ } /* 182 */ serializefromobject_holder.reset(); /* 183 */ /* 184 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 185 */ /* 186 */ if (mapelements_isNull) { /* 187 */ serializefromobject_rowWriter.setNullAt(0); /* 188 */ } else { /* 189 */ // Remember the current cursor so that we can calculate how many bytes are /* 190 */ // written later. /* 191 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 192 */ /* 193 */ if (serializefromobject_value instanceof UnsafeMapData) { /* 194 */ final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes(); /* 195 */ // grow the global buffer before writing data. 
/* 196 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 197 */ ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 198 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 199 */ /* 200 */ } else { /* 201 */ final ArrayData serializefromobject_keys = serializefromobject_value.keyArray(); /* 202 */ final ArrayData serializefromobject_values = serializefromobject_value.valueArray(); /* 203 */ /* 204 */ // preserve 8 bytes to write the key array numBytes later. /* 205 */ serializefromobject_holder.grow(8); /* 206 */ serializefromobject_holder.cursor += 8; /* 207 */ /* 208 */ // Remember the current cursor so that we can write numBytes of key array later. /* 209 */ final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor; /* 210 */ /* 211 */ if (serializefromobject_keys instanceof UnsafeArrayData) { /* 212 */ final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes(); /* 213 */ // grow the global buffer before writing data. /* 214 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes1); /* 215 */ ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 216 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes1; /* 217 */ /* 218 */ } else { /* 219 */ final int serializefromobject_numElements = serializefromobject_keys.numElements(); /* 220 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4); /* 221 */ /* 222 */ for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) { /* 223 */ if (serializefromobject_keys.isNullAt(serializefromobject_index1)) { /* 224 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index1); /* 225 */ } else { /* 226 */ final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1); /* 227 */ serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element); /* 228 */ } /* 229 */ } /* 230 */ } /* 231 */ /* 232 */ // Write the numBytes of key array into the first 8 bytes. /* 233 */ Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1); /* 234 */ /* 235 */ if (serializefromobject_values instanceof UnsafeArrayData) { /* 236 */ final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes(); /* 237 */ // grow the global buffer before writing data. 
/* 238 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes2); /* 239 */ ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 240 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes2; /* 241 */ /* 242 */ } else { /* 243 */ final int serializefromobject_numElements1 = serializefromobject_values.numElements(); /* 244 */ serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4); /* 245 */ /* 246 */ for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) { /* 247 */ if (serializefromobject_values.isNullAt(serializefromobject_index2)) { /* 248 */ serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2); /* 249 */ } else { /* 250 */ final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2); /* 251 */ serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1); /* 252 */ } /* 253 */ } /* 254 */ } /* 255 */ /* 256 */ } /* 257 */ /* 258 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 259 */ } /* 260 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 261 */ append(serializefromobject_result); /* 262 */ if (shouldStop()) return; /* 263 */ } /* 264 */ } /* 265 */ }
```

## How was this patch tested?

```
build/mvn -DskipTests clean package && dev/run-tests
```

Additionally in Spark shell:

```
scala> Seq(collection.mutable.HashMap(1 -> 2, 2 -> 3)).toDS().map(_ += (3 -> 4)).collect()
res0: Array[scala.collection.mutable.HashMap[Int,Int]] = Array(Map(2 -> 3, 1 -> 2, 3 -> 4))
```

Author: Michal Senkyr
Author: Michal Šenkýř

Closes #16986 from michalsenkyr/dataset-map-builder.
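For context, here is a minimal usage sketch of the generic map encoder added to `SQLImplicits` (assuming this patch is applied; the local SparkSession setup and app name are illustrative only, not part of the change):

```
import scala.collection.mutable.{LinkedHashMap => LHMap}
import org.apache.spark.sql.SparkSession

// Illustrative session setup; any SparkSession built against this patch will do.
val spark = SparkSession.builder().master("local[*]").appName("map-encoder-example").getOrCreate()
import spark.implicits._  // now provides newMapEncoder for T <: scala.collection.Map[_, _]

// Datasets of immutable Scala maps can be created and transformed directly.
val ds = Seq(Map(1 -> 2), Map(3 -> 4, 5 -> 6)).toDS()
ds.map(m => m.map { case (k, v) => k -> (v * 10) }).collect()

// Mutable map types are reconstructed through their builders, as in the spark-shell check above.
Seq(LHMap("a" -> 1)).toDS().map(_ += ("b" -> 2)).collect()
```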
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0538f3b0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0538f3b0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0538f3b0 Branch: refs/heads/master Commit: 0538f3b0ae4b80750ab81b210ad6fe77178337bf Parents: a7c61c1 Author: Michal Senkyr Authored: Mon Jun 12 08:47:01 2017 +0800 Committer: Wenchen Fan Committed: Mon Jun 12 08:47:01 2017 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/ScalaReflection.scala | 33 +--- .../catalyst/expressions/objects/objects.scala | 169 ++++++++++++++++++- .../sql/catalyst/ScalaReflectionSuite.scala | 25 +++ .../org/apache/spark/sql/SQLImplicits.scala | 5 + .../spark/sql/DatasetPrimitiveSuite.scala | 86 ++++++++++ 5 files changed, 291 insertions(+), 27 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8713053..d580cf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -335,31 +335,12 @@ object ScalaReflection extends ScalaReflection { // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t - val keyData = - Invoke( - MapObjects( - p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), - returnNullable = false), - schemaFor(keyType).dataType), - "array", - ObjectType(classOf[Array[Any]]), returnNullable = false) - - val valueData = - Invoke( - MapObjects( - p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType), - returnNullable = false), - schemaFor(valueType).dataType), - "array", - ObjectType(classOf[Array[Any]]), returnNullable = false) - - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[scala.collection.immutable.Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil) + CollectObjectsToMap( + p => deserializerFor(keyType, Some(p), walkedTypePath), + p => deserializerFor(valueType, Some(p), walkedTypePath), + getPath, + mirror.runtimeClass(t.typeSymbol.asClass) + ) case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1a202ec..79b7b9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ /** @@ -652,6 +652,173 @@ case class MapObjects private( } } +object CollectObjectsToMap { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + /** + * Construct an instance of CollectObjectsToMap case class. + * + * @param keyFunction The function applied on the key collection elements. + * @param valueFunction The function applied on the value collection elements. + * @param inputData An expression that when evaluated returns a map object. + * @param collClass The type of the resulting collection. + */ + def apply( + keyFunction: Expression => Expression, + valueFunction: Expression => Expression, + inputData: Expression, + collClass: Class[_]): CollectObjectsToMap = { + val id = curId.getAndIncrement() + val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id" + val mapType = inputData.dataType.asInstanceOf[MapType] + val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false) + val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id" + val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id" + val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) + CollectObjectsToMap( + keyLoopValue, keyFunction(keyLoopVar), + valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar), + inputData, collClass) + } +} + +/** + * Expression used to convert a Catalyst Map to an external Scala Map. + * The collection is constructed using the associated builder, obtained by calling `newBuilder` + * on the collection's companion object. + * + * @param keyLoopValue the name of the loop variable that is used when iterating over the key + * collection, and which is used as input for the `keyLambdaFunction` + * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param valueLoopValue the name of the loop variable that is used when iterating over the value + * collection, and which is used as input for the `valueLambdaFunction` + * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over + * the value collection, and which is used as input for the + * `valueLambdaFunction` + * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param inputData An expression that when evaluated returns a map object. + * @param collClass The type of the resulting collection. 
+ */ +case class CollectObjectsToMap private( + keyLoopValue: String, + keyLambdaFunction: Expression, + valueLoopValue: String, + valueLoopIsNull: String, + valueLambdaFunction: Expression, + inputData: Expression, + collClass: Class[_]) extends Expression with NonSQLExpression { + + override def nullable: Boolean = inputData.nullable + + override def children: Seq[Expression] = + keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def dataType: DataType = ObjectType(collClass) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + def inputDataType(dataType: DataType) = dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => dataType + } + + val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] + val keyElementJavaType = ctx.javaType(mapType.keyType) + ctx.addMutableState(keyElementJavaType, keyLoopValue, "") + val genKeyFunction = keyLambdaFunction.genCode(ctx) + val valueElementJavaType = ctx.javaType(mapType.valueType) + ctx.addMutableState("boolean", valueLoopIsNull, "") + ctx.addMutableState(valueElementJavaType, valueLoopValue, "") + val genValueFunction = valueLambdaFunction.genCode(ctx) + val genInputData = inputData.genCode(ctx) + val dataLength = ctx.freshName("dataLength") + val loopIndex = ctx.freshName("loopIndex") + val tupleLoopValue = ctx.freshName("tupleLoopValue") + val builderValue = ctx.freshName("builderValue") + + val getLength = s"${genInputData.value}.numElements()" + + val keyArray = ctx.freshName("keyArray") + val valueArray = ctx.freshName("valueArray") + val getKeyArray = + s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" + val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getValueArray = + s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" + val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) + + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? 
$value.copy() : $value" + def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) = + lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) + val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) + + val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + + val builderClass = classOf[Builder[_, _]].getName + val constructBuilder = s""" + $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder(); + $builderValue.sizeHint($dataLength); + """ + + val tupleClass = classOf[(_, _)].getName + val appendToBuilder = s""" + $tupleClass $tupleLoopValue; + + if (${genValueFunction.isNull}) { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null); + } else { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue); + } + + $builderValue.$$plus$$eq($tupleLoopValue); + """ + val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" + + val code = s""" + ${genInputData.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + + if (!${genInputData.isNull}) { + int $dataLength = $getLength; + $constructBuilder + $getKeyArray + $getValueArray + + int $loopIndex = 0; + while ($loopIndex < $dataLength) { + $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar); + $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar); + $valueLoopNullCheck + + ${genKeyFunction.code} + ${genValueFunction.code} + + $appendToBuilder + + $loopIndex += 1; + } + + $getBuilderResult + } + """ + ev.copy(code = code, isNull = genInputData.isNull) + } +} + object ExternalMapToCatalyst { private val curId = new java.util.concurrent.atomic.AtomicInteger() http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 70ad064..ff2414b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -314,6 +314,31 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } + test("serialize and deserialize arbitrary map types") { + val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( + 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) + assert(mapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mapDeserializer = deserializerFor[Map[Int, Int]] + assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) + + import scala.collection.immutable.HashMap + val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( + 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) + assert(hashMapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull 
= false)) + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] + assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) + + import scala.collection.mutable.{LinkedHashMap => LHMap} + val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( + 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) + assert(linkedHashMapSerializer.dataType.head.dataType == + MapType(LongType, StringType, valueContainsNull = true)) + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] + assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) + } + private val dataTypeForComplexData = dataTypeFor[ComplexData] private val typeOfComplexData = typeOf[ComplexData] http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 17671ea..86574e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.Map import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -166,6 +167,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 2.2.0 */ implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Maps + /** @since 2.3.0 */ + implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Arrays /** @since 1.6.1 */ http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 7e2949a..4126660 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.collection.immutable.Queue +import scala.collection.mutable.{LinkedHashMap => LHMap} import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.test.SharedSQLContext @@ -30,8 +31,14 @@ case class ListClass(l: List[Int]) case class QueueClass(q: Queue[Int]) +case class MapClass(m: Map[Int, Int]) + +case class LHMapClass(m: LHMap[Int, Int]) + case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) +case class ComplexMapClass(map: MapClass, lhmap: LHMapClass) + package object packageobject { case class PackageClass(value: Int) } @@ -258,11 +265,90 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))) } + test("arbitrary maps") { + checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2)) + checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong)) + checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble)) + checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat)) + checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte)) + checkDataset(Seq(Map(1.toShort -> 
2.toShort)).toDS(), Map(1.toShort -> 2.toShort)) + checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false)) + checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2")) + checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2))) + checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong)) + + checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2)) + checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 2.toLong)) + checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble)) + checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat)) + checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte)) + checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort)) + checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false)) + checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2")) + checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2))) + checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong)) + } + + ignore("SPARK-19104: map and product combinations") { + // Case classes + checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2))) + checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(), + Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(), + LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + + checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2))) + checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + Map(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + LHMap(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + + val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4))) + checkDataset(Seq(complex).toDS(), complex) + checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex)) + checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5)) + checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex)) + checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex)) + checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5)) + checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex)) + + // Tuples + checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 
4)).toDS(), Map(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(), + LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))) + + // Complex + checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))).toDS(), + LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))) + } + test("nested sequences") { checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1))) checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) } + test("nested maps") { + checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3))) + checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3)) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org