arrow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From w...@apache.org
Subject [1/2] arrow git commit: ARROW-655: [C++/Python] Implement DecimalArray
Date Sun, 09 Apr 2017 19:19:59 GMT
Repository: arrow
Updated Branches:
  refs/heads/master 449f99162 -> 754bcce68


http://git-wip-us.apache.org/repos/asf/arrow/blob/754bcce6/python/pyarrow/schema.pyx
----------------------------------------------------------------------
diff --git a/python/pyarrow/schema.pyx b/python/pyarrow/schema.pyx
index 253be45..4b931bf 100644
--- a/python/pyarrow/schema.pyx
+++ b/python/pyarrow/schema.pyx
@@ -29,6 +29,7 @@ from pyarrow.array cimport Array
 from pyarrow.error cimport check_status
 from pyarrow.includes.libarrow cimport (CDataType, CStructType, CListType,
                                         CFixedSizeBinaryType,
+                                        CDecimalType,
                                         TimeUnit_SECOND, TimeUnit_MILLI,
                                         TimeUnit_MICRO, TimeUnit_NANO,
                                         Type, TimeUnit)
@@ -45,7 +46,7 @@ cdef class DataType:
     def __cinit__(self):
         pass
 
-    cdef init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type):
         self.sp_type = type
         self.type = type.get()
 
@@ -66,14 +67,14 @@ cdef class DataType:
 
 cdef class DictionaryType(DataType):
 
-    cdef init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type):
         DataType.init(self, type)
         self.dict_type = <const CDictionaryType*> type.get()
 
 
 cdef class TimestampType(DataType):
 
-    cdef init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type):
         DataType.init(self, type)
         self.ts_type = <const CTimestampType*> type.get()
 
@@ -93,7 +94,7 @@ cdef class TimestampType(DataType):
 
 cdef class FixedSizeBinaryType(DataType):
 
-    cdef init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type):
         DataType.init(self, type)
         self.fixed_size_binary_type = <const CFixedSizeBinaryType*> type.get()
 
@@ -103,6 +104,13 @@ cdef class FixedSizeBinaryType(DataType):
             return self.fixed_size_binary_type.byte_width()
 
 
+cdef class DecimalType(FixedSizeBinaryType):
+
+    cdef void init(self, const shared_ptr[CDataType]& type):
+        DataType.init(self, type)
+        self.decimal_type = <const CDecimalType*> type.get()
+
+
 cdef class Field:
 
     def __cinit__(self):
@@ -354,6 +362,12 @@ def float64():
     return primitive_type(la.Type_DOUBLE)
 
 
+cpdef DataType decimal(int precision, int scale=0):
+    cdef shared_ptr[CDataType] decimal_type
+    decimal_type.reset(new CDecimalType(precision, scale))
+    return box_data_type(decimal_type)
+
+
 def string():
     """
     UTF8 string
@@ -374,11 +388,9 @@ def binary(int length=-1):
     if length == -1:
         return primitive_type(la.Type_BINARY)
 
-    cdef FixedSizeBinaryType out = FixedSizeBinaryType()
     cdef shared_ptr[CDataType] fixed_size_binary_type
     fixed_size_binary_type.reset(new CFixedSizeBinaryType(length))
-    out.init(fixed_size_binary_type)
-    return out
+    return box_data_type(fixed_size_binary_type)
 
 
 def list_(DataType value_type):
@@ -436,6 +448,8 @@ cdef DataType box_data_type(const shared_ptr[CDataType]& type):
         out = TimestampType()
     elif type.get().type == la.Type_FIXED_SIZE_BINARY:
         out = FixedSizeBinaryType()
+    elif type.get().type == la.Type_DECIMAL:
+        out = DecimalType()
     else:
         out = DataType()
 

http://git-wip-us.apache.org/repos/asf/arrow/blob/754bcce6/python/pyarrow/tests/test_convert_builtin.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_convert_builtin.py b/python/pyarrow/tests/test_convert_builtin.py
index e2b03d8..d89a8e0 100644
--- a/python/pyarrow/tests/test_convert_builtin.py
+++ b/python/pyarrow/tests/test_convert_builtin.py
@@ -20,6 +20,7 @@ from pyarrow.compat import unittest, u  # noqa
 import pyarrow as pa
 
 import datetime
+import decimal
 
 
 class TestConvertList(unittest.TestCase):
@@ -162,3 +163,42 @@ class TestConvertList(unittest.TestCase):
         data = ['a', 1, 2.0]
         with self.assertRaises(pa.ArrowException):
             pa.from_pylist(data)
+
+    def test_decimal(self):
+        data = [decimal.Decimal('1234.183'), decimal.Decimal('8094.234')]
+        type = pa.decimal(precision=7, scale=3)
+        arr = pa.from_pylist(data, type=type)
+        assert arr.to_pylist() == data
+
+    def test_decimal_different_precisions(self):
+        data = [
+            decimal.Decimal('1234234983.183'), decimal.Decimal('80943244.234')
+        ]
+        type = pa.decimal(precision=13, scale=3)
+        arr = pa.from_pylist(data, type=type)
+        assert arr.to_pylist() == data
+
+    def test_decimal_no_scale(self):
+        data = [decimal.Decimal('1234234983'), decimal.Decimal('8094324')]
+        type = pa.decimal(precision=10)
+        arr = pa.from_pylist(data, type=type)
+        assert arr.to_pylist() == data
+
+    def test_decimal_negative(self):
+        data = [decimal.Decimal('-1234.234983'), decimal.Decimal('-8.094324')]
+        type = pa.decimal(precision=10, scale=6)
+        arr = pa.from_pylist(data, type=type)
+        assert arr.to_pylist() == data
+
+    def test_decimal_no_whole_part(self):
+        data = [decimal.Decimal('-.4234983'), decimal.Decimal('.0103943')]
+        type = pa.decimal(precision=7, scale=7)
+        arr = pa.from_pylist(data, type=type)
+        assert arr.to_pylist() == data
+
+    def test_decimal_large_integer(self):
+        data = [decimal.Decimal('-394029506937548693.42983'),
+                decimal.Decimal('32358695912932.01033')]
+        type = pa.decimal(precision=23, scale=5)
+        arr = pa.from_pylist(data, type=type)
+        assert arr.to_pylist() == data

http://git-wip-us.apache.org/repos/asf/arrow/blob/754bcce6/python/pyarrow/tests/test_convert_pandas.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_convert_pandas.py b/python/pyarrow/tests/test_convert_pandas.py
index 87c9c03..0504e1d 100644
--- a/python/pyarrow/tests/test_convert_pandas.py
+++ b/python/pyarrow/tests/test_convert_pandas.py
@@ -20,6 +20,7 @@ from collections import OrderedDict
 
 import datetime
 import unittest
+import decimal
 
 import numpy as np
 
@@ -451,3 +452,72 @@ class TestPandasConversion(unittest.TestCase):
             self._check_pandas_roundtrip(df)
             self._check_array_roundtrip(col)
             self._check_array_roundtrip(col, mask=strided_mask)
+
+    def test_decimal_32_from_pandas(self):
+        expected = pd.DataFrame({
+            'decimals': [
+                decimal.Decimal('-1234.123'),
+                decimal.Decimal('1234.439'),
+            ]
+        })
+        converted = A.Table.from_pandas(expected)
+        field = A.Field.from_py('decimals', A.decimal(7, 3))
+        schema = A.Schema.from_fields([field])
+        assert converted.schema.equals(schema)
+
+    def test_decimal_32_to_pandas(self):
+        expected = pd.DataFrame({
+            'decimals': [
+                decimal.Decimal('-1234.123'),
+                decimal.Decimal('1234.439'),
+            ]
+        })
+        converted = A.Table.from_pandas(expected)
+        df = converted.to_pandas()
+        tm.assert_frame_equal(df, expected)
+
+    def test_decimal_64_from_pandas(self):
+        expected = pd.DataFrame({
+            'decimals': [
+                decimal.Decimal('-129934.123331'),
+                decimal.Decimal('129534.123731'),
+            ]
+        })
+        converted = A.Table.from_pandas(expected)
+        field = A.Field.from_py('decimals', A.decimal(12, 6))
+        schema = A.Schema.from_fields([field])
+        assert converted.schema.equals(schema)
+
+    def test_decimal_64_to_pandas(self):
+        expected = pd.DataFrame({
+            'decimals': [
+                decimal.Decimal('-129934.123331'),
+                decimal.Decimal('129534.123731'),
+            ]
+        })
+        converted = A.Table.from_pandas(expected)
+        df = converted.to_pandas()
+        tm.assert_frame_equal(df, expected)
+
+    def test_decimal_128_from_pandas(self):
+        expected = pd.DataFrame({
+            'decimals': [
+                decimal.Decimal('394092382910493.12341234678'),
+                -decimal.Decimal('314292388910493.12343437128'),
+            ]
+        })
+        converted = A.Table.from_pandas(expected)
+        field = A.Field.from_py('decimals', A.decimal(26, 11))
+        schema = A.Schema.from_fields([field])
+        assert converted.schema.equals(schema)
+
+    def test_decimal_128_to_pandas(self):
+        expected = pd.DataFrame({
+            'decimals': [
+                decimal.Decimal('394092382910493.12341234678'),
+                -decimal.Decimal('314292388910493.12343437128'),
+            ]
+        })
+        converted = A.Table.from_pandas(expected)
+        df = converted.to_pandas()
+        tm.assert_frame_equal(df, expected)


Mime
View raw message