avro-commits mailing list archives

From cutt...@apache.org
Subject svn commit: r790264 - in /hadoop/avro/trunk: ./ src/java/org/apache/avro/generic/ src/py/avro/ src/test/py/
Date Wed, 01 Jul 2009 16:53:37 GMT
Author: cutting
Date: Wed Jul  1 16:53:37 2009
New Revision: 790264

URL: http://svn.apache.org/viewvc?rev=790264&view=rev
Log:
AVRO-28.  Add Python support for default values.  Contributed by sharad.

Modified:
    hadoop/avro/trunk/CHANGES.txt
    hadoop/avro/trunk/src/java/org/apache/avro/generic/GenericDatumReader.java
    hadoop/avro/trunk/src/py/avro/genericio.py
    hadoop/avro/trunk/src/py/avro/io.py
    hadoop/avro/trunk/src/py/avro/protocol.py
    hadoop/avro/trunk/src/py/avro/reflectio.py
    hadoop/avro/trunk/src/py/avro/reflectipc.py
    hadoop/avro/trunk/src/py/avro/schema.py
    hadoop/avro/trunk/src/test/py/testio.py
    hadoop/avro/trunk/src/test/py/testioreflect.py
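
For context, the behavior added by this commit is exercised by the new checkdefault() test in src/test/py/testio.py below: a DatumReader built from a writer's ("actual") schema and a reader's ("expected") schema fills in fields missing from the written data using the reader schema's declared defaults. The following sketch is illustrative only and is not part of the commit; it assumes the trunk Python modules (avro.schema, avro.genericio, avro.io) are importable as shown, and the record name "Foo" and field "f" simply mirror the test.

import cStringIO
from avro import schema, genericio, io

# Writer's ("actual") schema declares no fields; the reader's ("expected")
# schema declares a field "f" of type long with a default of 11.
actual = schema.parse('{"type": "record", "name": "Foo", "fields": []}')
expected = schema.parse('{"type": "record", "name": "Foo", "fields": '
                        '[{"name": "f", "type": "long", "default": 11}]}')

# With nothing on the wire, readrecord() falls back to the default value
# via _defaultfieldvalue(), so the resulting dict contains {"f": 11}.
reader = genericio.DatumReader(actual, expected)
record = reader.read(io.Decoder(cStringIO.StringIO()))
assert record.get("f") == 11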

Modified: hadoop/avro/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/CHANGES.txt?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/CHANGES.txt (original)
+++ hadoop/avro/trunk/CHANGES.txt Wed Jul  1 16:53:37 2009
@@ -46,6 +46,8 @@
 
     AVRO-67.  Add per-call RPC metadata to spec. (George Porter via cutting)
 
+    AVRO-28. Add Python support for default values. (sharad via cutting)
+
   IMPROVEMENTS
 
     AVRO-11.  Re-implement specific and reflect datum readers and

Modified: hadoop/avro/trunk/src/java/org/apache/avro/generic/GenericDatumReader.java
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/java/org/apache/avro/generic/GenericDatumReader.java?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/java/org/apache/avro/generic/GenericDatumReader.java (original)
+++ hadoop/avro/trunk/src/java/org/apache/avro/generic/GenericDatumReader.java Wed Jul  1 16:53:37 2009
@@ -89,6 +89,7 @@
       if (branch.getType() == actual.getType())
         switch (branch.getType()) {
         case RECORD:
+        case ENUM:
         case FIXED:
           String name = branch.getName();
           if (name == null || name.equals(actual.getName()))
@@ -101,21 +102,21 @@
     for (Schema branch : expected.getTypes())
       switch (actual.getType()) {
       case INT:
-        switch (expected.getType()) {
+        switch (branch.getType()) {
         case LONG: case FLOAT: case DOUBLE:
-          return expected;
+          return branch;
         }
         break;
       case LONG:
-        switch (expected.getType()) {
+        switch (branch.getType()) {
         case FLOAT: case DOUBLE:
-          return expected;
+          return branch;
         }
         break;
       case FLOAT:
-        switch (expected.getType()) {
+        switch (branch.getType()) {
         case DOUBLE:
-          return expected;
+          return branch;
         }
         break;
       }

Modified: hadoop/avro/trunk/src/py/avro/genericio.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/genericio.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/genericio.py (original)
+++ hadoop/avro/trunk/src/py/avro/genericio.py Wed Jul  1 16:53:37 2009
@@ -52,8 +52,8 @@
 def _validaterecord(schm, object):
   if not isinstance(object, dict):
     return False
-  for field,fieldschema in schm.getfields():
-    if not validate(fieldschema, object.get(field)):
+  for field in schm.getfields().values():
+    if not validate(field.getschema(), object.get(field.getname())):
       return False
   return True
 
@@ -98,72 +98,240 @@
 class DatumReader(io.DatumReaderBase):
   """DatumReader for generic python objects."""
 
-  def __init__(self, schm=None):
-    self.setschema(schm)
+  def __init__(self, actual=None, expected=None):
+    self.setschema(actual)
+    self.__expected = expected
     self.__readfn = {
-     schema.BOOLEAN : lambda schm, decoder: decoder.readboolean(),
-     schema.STRING : lambda schm, decoder: decoder.readutf8(),
-     schema.INT : lambda schm, decoder: decoder.readint(),
-     schema.LONG : lambda schm, decoder: decoder.readlong(),
-     schema.FLOAT : lambda schm, decoder: decoder.readfloat(),
-     schema.DOUBLE : lambda schm, decoder: decoder.readdouble(),
-     schema.BYTES : lambda schm, decoder: decoder.readbytes(),
-     schema.FIXED : lambda schm, decoder: 
-                            (decoder.read(schm.getsize())),
+     schema.BOOLEAN : lambda actual, expected, decoder: decoder.readboolean(),
+     schema.STRING : lambda actual, expected, decoder: decoder.readutf8(),
+     schema.INT : lambda actual, expected, decoder: decoder.readint(),
+     schema.LONG : lambda actual, expected, decoder: decoder.readlong(),
+     schema.FLOAT : lambda actual, expected, decoder: decoder.readfloat(),
+     schema.DOUBLE : lambda actual, expected, decoder: decoder.readdouble(),
+     schema.BYTES : lambda actual, expected, decoder: decoder.readbytes(),
+     schema.FIXED : self.readfixed,
      schema.ARRAY : self.readarray,
      schema.MAP : self.readmap,
      schema.RECORD : self.readrecord,
-     schema.ENUM : self.readenum,
-     schema.UNION : self.readunion
+     schema.ENUM : self.readenum
+     }
+    self.__skipfn = {
+     schema.BOOLEAN : lambda schm, decoder: decoder.skipboolean(),
+     schema.STRING : lambda schm, decoder: decoder.skiputf8(),
+     schema.INT : lambda schm, decoder: decoder.skipint(),
+     schema.LONG : lambda schm, decoder: decoder.skiplong(),
+     schema.FLOAT : lambda schm, decoder: decoder.skipfloat(),
+     schema.DOUBLE : lambda schm, decoder: decoder.skipdouble(),
+     schema.BYTES : lambda schm, decoder: decoder.skipbytes(),
+     schema.FIXED : self.skipfixed,
+     schema.ARRAY : self.skiparray,
+     schema.MAP : self.skipmap,
+     schema.RECORD : self.skiprecord,
+     schema.ENUM : self.skipenum,
+     schema.UNION : self.skipunion
      }
 
   def setschema(self, schm):
-    self.__schm = schm
+    self.__actual = schm
 
   def read(self, decoder):
-    return self.readdata(self.__schm, decoder)
-    
-  def readdata(self, schm, decoder):
-    if schm.gettype() == schema.NULL:
+    if self.__expected is None:
+      self.__expected = self.__actual
+    return self.readdata(self.__actual, self.__expected, decoder)
+
+  def readdata(self, actual, expected, decoder):
+    if actual.gettype() == schema.UNION:
+      actual = actual.getelementtypes()[int(decoder.readlong())]
+    if expected.gettype() == schema.UNION:
+      expected = self._resolve(actual, expected)
+    if actual.gettype() == schema.NULL:
       return None
-    fn = self.__readfn.get(schm.gettype())
+    fn = self.__readfn.get(actual.gettype())
+    if fn is not None:
+      return fn(actual, expected, decoder)
+    else:
+      raise schema.AvroException("Unknown type: "+schema.stringval(actual));
+
+  def skipdata(self, schm, decoder):
+    fn = self.__skipfn.get(schm.gettype())
     if fn is not None:
       return fn(schm, decoder)
     else:
-      raise AvroException("Unknown type: "+schema.stringval(schm));
+      raise schema.AvroException("Unknown type: "+schema.stringval(schm));
 
-  def readmap(self, schm, decoder):
+  def readfixed(self, actual, expected, decoder):
+    self.__checkname(actual, expected)
+    if actual.getsize() != expected.getsize():
+      self.__raisematchException(actual, expected)
+    return decoder.read(actual.getsize())
+
+  def skipfixed(self, schm):
+    return decoder.skip(actual.getsize())
+
+  def readmap(self, actual, expected, decoder):
+    if (actual.getvaluetype().gettype() != 
+          expected.getvaluetype().gettype()):
+      self.__raisematchException(actual, expected)
     result = dict()
     size = decoder.readlong()
     if size != 0:
       for i in range(0, size):
         key = decoder.readutf8()
-        result[key] = self.readdata(schm.getvaluetype(), decoder)
+        result[key] = self.readdata(actual.getvaluetype(), 
+                                    expected.getvaluetype(), decoder)
       decoder.readlong()
     return result
 
-  def readarray(self, schm, decoder):
+  def skipmap(self, schm, decoder):
+    size = decoder.readlong()
+    if size != 0:
+      for i in range(0, size):
+        decoder.skiputf8()
+        self.skipdata(schm.getvaluetype(), decoder)
+      decoder.skiplong()
+
+  def readarray(self, actual, expected, decoder):
+    if (actual.getelementtype().gettype() != 
+          expected.getelementtype().gettype()):
+      self.__raisematchException(actual, expected)
     result = list()
     size = decoder.readlong()
     if size != 0:
       for i in range(0, size):
-        result.append(self.readdata(schm.getelementtype(), decoder))
+        result.append(self.readdata(actual.getelementtype(), 
+                                    expected.getelementtype(), decoder))
       decoder.readlong()
     return result
 
-  def readrecord(self, schm, decoder):
-    result = dict() 
-    for field,fieldschema in schm.getfields():
-      result[field] = self.readdata(fieldschema, decoder)
-    return result
+  def skiparray(self, schm, decoder):
+    size = decoder.readlong()
+    if size != 0:
+      for i in range(0, size):
+        self.skipdata(schm.getelementtype(), decoder)
+      decoder.skiplong()
+
+  def createrecord(self, schm):
+    return dict()
+
+  def addfield(self, record, name, value):
+     record[name] = value
 
-  def readenum(self, schm, decoder):
+  def readrecord(self, actual, expected, decoder):
+    self.__checkname(actual, expected)
+    expectedfields = expected.getfields()
+    record = self.createrecord(actual)
+    size = 0 
+    for fieldname, field in actual.getfields().items():
+      if expected == actual:
+        expectedfield = field
+      else:
+        expectedfield = expectedfields.get(fieldname)
+      if expectedfield is None:
+        self.skipdata(field.getschema(), decoder)
+        continue
+      self.addfield(record, fieldname, self.readdata(field.getschema(), 
+                                        expectedfield.getschema(), decoder))
+      size += 1
+    if len(expectedfields) > size:  # not all fields set
+      actualfields = actual.getfields()
+      for fieldname, field in expectedfields.items():
+        if not actualfields.has_key(fieldname):
+          defval = field.getdefaultvalue()
+          if defval is not None:
+            self.addfield(record, fieldname, 
+                      self._defaultfieldvalue(field.getschema(), defval))
+    return record
+
+  def skiprecord(self, schm, decoder):
+    for field in schm.getfields().values():
+      self.skipdata(field.getschema(), decoder)
+
+  def readenum(self, actual, expected, decoder):
+    self.__checkname(actual, expected)
     index = decoder.readint()
-    return schm.getenumsymbols()[index]
+    return actual.getenumsymbols()[index]
+
+  def skipenum(self, schm, decoder):
+    decoder.skipint()
 
-  def readunion(self, schm, decoder):
+  def skipunion(self, schm, decoder):
     index = int(decoder.readlong())
-    return self.readdata(schm.getelementtypes()[index], decoder)
+    return self.skipdata(schm.getelementtypes()[index], decoder)
+
+  def _resolve(self, actual, expected):
+    # scan for exact match
+    for branch in expected.getelementtypes():
+      if branch.gettype() == actual.gettype():
+        return branch
+    #scan for match via numeric promotion
+    for branch in expected.getelementtypes():
+      actualtype = actual.gettype()
+      expectedtype = branch.gettype()
+      if actualtype == schema.INT:
+        if (expectedtype == schema.LONG or expectedtype == schema.FLOAT 
+            or expectedtype == schema.DOUBLE):
+          return branch
+      elif actualtype == schema.LONG:
+        if (expectedtype == schema.FLOAT or expectedtype == schema.DOUBLE):
+          return branch
+      elif actualtype == schema.FLOAT:
+        if (expectedtype == schema.DOUBLE):
+          return branch
+    self.__raisematchException(actual, expected)
+
+  def __checkname(self, actual, expected):
+    if actual.getname() != expected.getname():
+      self.__raisematchException(actual, expected)
+
+  def __raisematchException(self, actual, expected):
+    raise schema.AvroException("Expected "+schema.stringval(expected)+
+                        ", found "+schema.stringval(actual))
+
+  def _defaultfieldvalue(self, schm, defaultnode):
+    if schm.gettype() == schema.RECORD:
+      record = self.createrecord(schm)
+      for field in schm.getfields().values():
+        v = defaultnode.get(field.getname())
+        if v is None:
+          v = field.getdefaultvalue()
+        if v is not None:
+          record[field.getname()] = self._defaultfieldvalue(
+                                                  field.getschema(), v)
+      return record
+    elif schm.gettype() == schema.ENUM:
+      return defaultnode
+    elif schm.gettype() == schema.ARRAY:
+      array = list()
+      for node in defaultnode:
+        array.append(self._defaultfieldvalue(schm.getelementtype(), node))
+      return array
+    elif schm.gettype() == schema.MAP:
+      map = dict()
+      for k,v in defaultnode.items():
+        map[k] = self._defaultfieldvalue(schm.getvaluetype(), v)
+      return map
+    elif schm.gettype() == schema.UNION:
+      return self._defaultfieldvalue(schm.getelementtypes()[0], defaultnode)
+    elif schm.gettype() == schema.FIXED:
+      return defaultnode
+    elif schm.gettype() == schema.STRING:
+      return defaultnode
+    elif schm.gettype() == schema.BYTES:
+      return defaultnode
+    elif schm.gettype() == schema.INT:
+      return int(defaultnode)
+    elif schm.gettype() == schema.LONG:
+      return long(defaultnode)
+    elif schm.gettype() == schema.FLOAT:
+      return float(defaultnode)
+    elif schm.gettype() == schema.DOUBLE:
+      return float(defaultnode)
+    elif schm.gettype() == schema.BOOLEAN:
+      return bool(defaultnode)
+    elif schm.gettype() == schema.NULL:
+      return None
+    else:
+      raise schema.AvroException("Unknown type: "+schema.stringval(actual))
 
 class DatumWriter(io.DatumWriterBase):
   """DatumWriter for generic python objects."""
@@ -233,8 +401,8 @@
   def writerecord(self, schm, datum, encoder):
     if not isinstance(datum, dict):
       raise io.AvroTypeException(schm, datum)
-    for field,fieldschema in schm.getfields():
-      self.writedata(fieldschema, datum.get(field), encoder)
+    for field in schm.getfields().values():
+      self.writedata(field.getschema(), datum.get(field.getname()), encoder)
 
   def writeunion(self, schm, datum, encoder):
     index = self.resolveunion(schm, datum)

Modified: hadoop/avro/trunk/src/py/avro/io.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/io.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/io.py (original)
+++ hadoop/avro/trunk/src/py/avro/io.py Wed Jul  1 16:53:37 2009
@@ -108,6 +108,30 @@
   def read(self, len):
     return struct.unpack(len.__str__()+'s', self.__reader.read(len))[0]
 
+  def skipboolean(self):
+    self.skip(1)
+
+  def skipint(self):
+    self.skip(4)
+
+  def skiplong(self):
+    self.skip(8)
+
+  def skipfloat(self):
+    self.skip(4)
+
+  def skipdouble(self):
+    self.skip(8)
+
+  def skipbytes(self):
+    self.skip(self.readlong())
+
+  def skiputf8(self):
+    self.skipbytes()
+
+  def skip(self, len):
+    self.__reader.seek(self.__reader.tell()+len)
+
 class Encoder(object):
   """Write leaf values."""
 

Modified: hadoop/avro/trunk/src/py/avro/protocol.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/protocol.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/protocol.py (original)
+++ hadoop/avro/trunk/src/py/avro/protocol.py Wed Jul  1 16:53:37 2009
@@ -76,11 +76,11 @@
       str = cStringIO.StringIO()
       str.write("{\"request\": [")
       count = 0
-      for k,v in self.__request.getfields():
+      for field in self.__request.getfields().values():
         str.write("{\"name\": \"")
-        str.write(k)
+        str.write(field.getname())
         str.write("\", \"type\": ")
-        str.write(v.str(self.__proto.gettypes()))
+        str.write(field.getschema().str(self.__proto.gettypes()))
         str.write("}")
         count+=1
         if count < len(self.__request.getfields()):
@@ -158,8 +158,9 @@
       fieldtype = field.get("type")
       if fieldtype is None:
         raise SchemaParseException("No param type: "+field.__str__())
-      fields[fieldname] = schema._parse(fieldtype, self.__types)
-    request = schema._RecordSchema(list(fields.iteritems()))
+      fields[fieldname] = schema.Field(fieldname, 
+                                       schema._parse(fieldtype, self.__types))
+    request = schema._RecordSchema(fields)
     response = schema._parse(res, self.__types)
 
     erorrs = list()

Modified: hadoop/avro/trunk/src/py/avro/reflectio.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/reflectio.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/reflectio.py (original)
+++ hadoop/avro/trunk/src/py/avro/reflectio.py Wed Jul  1 16:53:37 2009
@@ -43,9 +43,9 @@
 def _validaterecord(schm, pkgname, object):
   if not isinstance(object, gettype(schm, pkgname)):
     return False
-  for field,fieldschema in schm.getfields():
-    data = object.__getattribute__(field)
-    if not validate(fieldschema, pkgname, data):
+  for field in schm.getfields().values():
+    data = object.__getattribute__(field.getname())
+    if not validate(field.getschema(), pkgname, data):
       return False
   return True
 
@@ -94,8 +94,8 @@
   clazz = globals().get(clazzname)
   if clazz is None:
     clazz = type(str(clazzname),(base,),{})
-    for field,fieldschema in recordschm.getfields():
-      setattr(clazz, field, None)
+    for field in recordschm.getfields().values():
+      setattr(clazz, field.getname(), None)
     globals()[clazzname] = clazz
   return clazz
 
@@ -106,12 +106,12 @@
     genericio.DatumReader.__init__(self, schm)
     self.__pkgname = pkgname
 
-  def readrecord(self, schm, decoder):
+  def addfield(self, record, name, value):
+    setattr(record, name, value)
+
+  def createrecord(self, schm):
     type = gettype(schm, self.__pkgname)
-    result = type()
-    for field,fieldschema in schm.getfields():
-      setattr(result, field, self.readdata(fieldschema, decoder))
-    return result
+    return type()
 
 class ReflectDatumWriter(genericio.DatumWriter):
   """DatumWriter for arbitrary python classes."""
@@ -121,8 +121,9 @@
     self.__pkgname = pkgname
 
   def writerecord(self, schm, datum, encoder):
-    for field,fieldschema in schm.getfields():
-      self.writedata(fieldschema, getattr(datum, field), encoder)
+    for field in schm.getfields().values():
+      self.writedata(field.getschema(), getattr(datum, field.getname()),
+                      encoder)
 
   def resolveunion(self, schm, datum):
     index = 0

Modified: hadoop/avro/trunk/src/py/avro/reflectipc.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/reflectipc.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/reflectipc.py (original)
+++ hadoop/avro/trunk/src/py/avro/reflectipc.py Wed Jul  1 16:53:37 2009
@@ -68,9 +68,8 @@
     return reflectio.ReflectDatumReader(self.__pkgname, schm)
 
   def writerequest(self, schm, req, encoder):
-    index = 0
     for arg in req:
-      argschm = schm.getfields()[index][1]
+      argschm = schm.getfields().values()[0].getschema()
       genericipc.Requestor.writerequest(self, argschm, arg, encoder)
 
   def readerror(self, schm, decoder):
@@ -91,8 +90,9 @@
 
   def readrequest(self, schm, decoder):
     req = list()
-    for field, fieldschm in schm.getfields():
-      req.append(genericipc.Responder.readrequest(self, fieldschm, decoder))
+    for field in schm.getfields().values():
+      req.append(genericipc.Responder.readrequest(self, field.getschema(), 
+                                               decoder))
     return req
 
   def writeerror(self, schm, error, encoder):

Modified: hadoop/avro/trunk/src/py/avro/schema.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/schema.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/schema.py (original)
+++ hadoop/avro/trunk/src/py/avro/schema.py Wed Jul  1 16:53:37 2009
@@ -142,6 +142,25 @@
       hash += self.__space.__hash__()
     return hash
 
+class Field(object):
+  def __init__(self, name, schema, defaultvalue=None):
+    self.__name = name
+    self.__schema = schema
+    self.__defaultvalue = defaultvalue
+
+  def getname(self):
+    return self.__name
+
+  def getschema(self):
+    return self.__schema
+
+  def getdefaultvalue(self):
+    return self.__defaultvalue
+
+  def __eq__(self, other, seen={}):
+    return (self.__name == other.__name and
+            self.__schema.__eq__(other.__schema, seen) and 
+            self.__defaultvalue == other.__defaultvalue)
 
 class _RecordSchema(NamedSchema):
   def __init__(self, fields, name=None, space=None, iserror=False):
@@ -170,11 +189,14 @@
     str.write(self.namestring())
     str.write("\"fields\": [")
     count=0
-    for k,v in self.__fields:
+    for field in self.__fields.values():
       str.write("{\"name\": \"")
-      str.write(k)
+      str.write(field.getname())
       str.write("\", \"type\": ")
-      str.write(v.str(names))
+      str.write(field.getschema().str(names))
+      if field.getdefaultvalue() is not None:
+        str.write(", \"default\": ")
+        str.write(repr(field.getdefaultvalue()))
       str.write("}")
       count+=1
       if count < len(self.__fields):
@@ -190,8 +212,8 @@
       if len(other.__fields) != size:
         return False
       seen[id(self)] = other
-      for i in range(0, size):
-        if not self.__fields[i][1].__eq__(other.__fields[i][1], seen):
+      for field in self.__fields.values():
+        if not field.__eq__(other.__fields.get(field.getname()), seen):
           return False
       return True
     else:
@@ -202,8 +224,8 @@
       return 0
     seen.add(id(self))
     hash = NamedSchema.__hash__(self, seen)
-    for field, fieldschm in self.__fields:
-      hash = hash + fieldschm.__hash__(seen)
+    for field in self.__fields.values():
+      hash = hash + field.getschema().__hash__(seen)
     return hash
 
 class _ArraySchema(Schema):
@@ -444,7 +466,7 @@
       if name is None:
         raise SchemaParseException("No name in schema: "+obj.__str__())
       if type == "record" or type == "error":
-        fields = list()
+        fields = odict.OrderedDict()
         schema = _RecordSchema(fields, name, space, type == "error")
         names[name] = schema
         fieldsnode = obj.get("fields")
@@ -457,7 +479,9 @@
           fieldtype = field.get("type")
           if fieldtype is None:
             raise SchemaParseException("No field type: "+field.__str__())
-          fields.append((fieldname, _parse(fieldtype, names)))
+          defaultval = field.get("default")
+          fields[fieldname] = Field(fieldname, _parse(fieldtype, names), 
+                                    defaultval)
         return schema
       elif type == "enum":
         symbolsnode = obj.get("symbols")

Modified: hadoop/avro/trunk/src/test/py/testio.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/test/py/testio.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/test/py/testio.py (original)
+++ hadoop/avro/trunk/src/test/py/testio.py Wed Jul  1 16:53:37 2009
@@ -75,8 +75,8 @@
       return map
     elif schm.gettype() == schema.RECORD:
       m = dict()
-      for field, fieldschm in schm.getfields():
-        m[field] = self.nextdata(fieldschm, d+1)
+      for field in schm.getfields().values():
+        m[field.getname()] = self.nextdata(field.getschema(), d+1)
       return m
     elif schm.gettype() == schema.UNION:
       datum = self.nextdata(random.choice(schm.getelementtypes()), d)
@@ -107,43 +107,45 @@
     self.__assertdata = assertdata
 
   def testNull(self):
-    self.check("\"null\"")
+    self.checkdefault("\"null\"", "null", None)
 
   def testBoolean(self):
-    self.check("\"boolean\"")
+    self.checkdefault("\"boolean\"", "true", True)
 
   def testString(self):
-    self.check("\"string\"")
+    self.checkdefault("\"string\"", "\"foo\"", "foo")
 
   def testBytes(self):
-    self.check("\"bytes\"")
+    self.checkdefault("\"bytes\"", "\"foo\"", "foo")
 
   def testInt(self):
-    self.check("\"int\"")
+    self.checkdefault("\"int\"", "5", 5)
 
   def testLong(self):
-    self.check("\"long\"")
+    self.checkdefault("\"long\"", "9", 9)
 
   def testFloat(self):
-    self.check("\"float\"")
+    self.checkdefault("\"float\"", "1.2", float(1.2))
 
   def testDouble(self):
-    self.check("\"double\"")
+    self.checkdefault("\"double\"", "1.2", float(1.2))
 
   def testArray(self):
-    self.check("{\"type\":\"array\", \"items\": \"long\"}")
+    self.checkdefault("{\"type\":\"array\", \"items\": \"long\"}",
+                       "[1]", [1])
 
   def testMap(self):
-    self.check("{\"type\":\"map\", \"values\": \"string\"}")
+    self.checkdefault("{\"type\":\"map\", \"values\": \"long\"}",
+                      "{\"a\":1}", {unicode("a"):1})
 
   def testRecord(self):
-    self.check("{\"type\":\"record\", \"name\":\"Test\"," +
+    self.checkdefault("{\"type\":\"record\", \"name\":\"Test\"," +
                "\"fields\":[{\"name\":\"f\", \"type\":" +
-               "\"string\"}, {\"name\":\"fb\", \"type\":\"bytes\"}]}")
+               "\"long\"}]}", "{\"f\":11}", {"f" : 11})
 
   def testEnum(self):
-    self.check("{\"type\": \"enum\", \"name\":\"Test\","+
-               "\"symbols\": [\"A\", \"B\"]}")
+    self.checkdefault("{\"type\": \"enum\", \"name\":\"Test\","+
+               "\"symbols\": [\"A\", \"B\"]}", "\"B\"", "B")
 
   def testRecursive(self):
     self.check("{\"type\": \"record\", \"name\": \"Node\", \"fields\": ["
@@ -163,9 +165,11 @@
       +"{\"type\": \"record\", \"name\": \"Cons\", \"fields\": ["
       +"{\"name\":\"car\", \"type\":\"string\"}," 
       +"{\"name\":\"cdr\", \"type\":\"string\"}]}]")
+    self.checkdefault("[\"double\", \"long\"]", "1.1", 1.1)
 
   def testFixed(self):
-    self.check("{\"type\": \"fixed\", \"name\":\"Test\", \"size\": 1}") 
+    self.checkdefault("{\"type\": \"fixed\", \"name\":\"Test\", \"size\": 1}", 
+                      "\"a\"", "a") 
 
   def check(self, string):
     schm = schema.parse(string)
@@ -180,6 +184,20 @@
       self.checkser(schm, randomdata)
     self.checkdatafile(schm)
 
+  def checkdefault(self, schemajson, defaultjson, defaultvalue):
+    self.check(schemajson)
+    actual = schema.parse("{\"type\":\"record\", \"name\":\"Foo\","
+                          + "\"fields\":[]}")
+    expected = schema.parse("{\"type\":\"record\", \"name\":\"Foo\"," 
+                             +"\"fields\":[{\"name\":\"f\", "
+                             +"\"type\":"+schemajson+", "
+                             +"\"default\":"+defaultjson+"}]}")
+    reader = genericio.DatumReader(actual, expected)
+    record = reader.read(io.Decoder(cStringIO.StringIO()))
+    self.assertEquals(defaultvalue, record.get("f"))
+    #FIXME fix to string for default values
+    #self.assertEquals(expected, schema.parse(schema.stringval(expected)))
+
   def checkser(self, schm, randomdata):
     datum = randomdata.next()
     self.assertTrue(self.__validator(schm, datum))

Modified: hadoop/avro/trunk/src/test/py/testioreflect.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/test/py/testioreflect.py?rev=790264&r1=790263&r2=790264&view=diff
==============================================================================
--- hadoop/avro/trunk/src/test/py/testioreflect.py (original)
+++ hadoop/avro/trunk/src/test/py/testioreflect.py Wed Jul  1 16:53:37 2009
@@ -29,8 +29,8 @@
     if schm.gettype() == schema.RECORD:
       clazz = reflectio.gettype(schm, _PKGNAME)
       result = clazz()
-      for field,fieldschema in schm.getfields():
-        result.__setattr__(field, self.nextdata(fieldschema,d))
+      for field in schm.getfields().values():
+        result.__setattr__(field.getname(), self.nextdata(field.getschema(),d))
       return result
     else:
       return testio.RandomData.nextdata(self, schm, d)


