tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jroe...@apache.org
Subject [incubator-tvm] 02/02: Refactor anyhow out of the rt layer
Date Sat, 30 May 2020 08:08:16 GMT
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch rust-tvm
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit a44a379bb3b3f4fab505dce3520eeb97f230ac23
Author: Jared Roesch <jroesch@octoml.ai>
AuthorDate: Sat May 30 01:07:46 2020 -0700

    Refactor anyhow out of the rt layer
---
 rust/Cargo.toml                      |  3 +-
 rust/macros/src/object.rs            |  8 ++---
 rust/tvm-rt/src/errors.rs            | 36 ++++++++++++++++----
 rust/tvm-rt/src/function.rs          | 66 +++++++++++++++++++++++++++++++++---
 rust/tvm-rt/src/lib.rs               |  5 +--
 rust/tvm-rt/src/ndarray.rs           | 66 +++++++++++++++++++-----------------
 rust/tvm-rt/src/object/mod.rs        | 17 ++++------
 rust/tvm-rt/src/object/object_ptr.rs | 30 +++++++++-------
 rust/tvm-rt/src/to_function.rs       | 37 ++++++++++----------
 rust/tvm-rt/src/value.rs             |  5 ++-
 rust/tvm/src/ir/array.rs             | 55 +++++++++++++++++-------------
 rust/tvm/src/lib.rs                  | 10 +-----
 rust/tvm/src/transform.rs            |  2 +-
 13 files changed, 211 insertions(+), 129 deletions(-)

diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 6d3481b..e107104 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -29,5 +29,6 @@ members = [
 	"frontend/tests/callback",
 	"frontend/examples/resnet",
     "tvm-sys",
-	"tvm-rt"
+	"tvm-rt",
+	"tvm",
 ]
diff --git a/rust/macros/src/object.rs b/rust/macros/src/object.rs
index 96a86dd..670d326 100644
--- a/rust/macros/src/object.rs
+++ b/rust/macros/src/object.rs
@@ -89,12 +89,12 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
         }
 
         impl std::convert::TryFrom<tvm_rt::RetValue> for #ref_id {
-            type Error = ::anyhow::Error;
+            type Error = tvm_rt::Error;
 
             fn try_from(ret_val: tvm_rt::RetValue) -> Result<#ref_id, Self::Error>
{
                 use std::convert::TryInto;
                 let oref: ObjectRef = ret_val.try_into()?;
-                let ptr = oref.0.ok_or(anyhow::anyhow!("null ptr"))?;
+                let ptr = oref.0.ok_or(tvm_rt::Error::Null)?;
                 let ptr = ptr.downcast::<#payload_id>()?;
                 Ok(#ref_id(Some(ptr)))
             }
@@ -122,7 +122,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
         }
 
         impl<'a> std::convert::TryFrom<tvm_rt::ArgValue<'a>> for #ref_id
{
-            type Error = anyhow::Error;
+            type Error = tvm_rt::Error;
 
             fn try_from(arg_value: tvm_rt::ArgValue<'a>) -> Result<#ref_id, Self::Error>
{
                 use std::convert::TryInto;
@@ -132,7 +132,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
         }
 
         impl<'a> std::convert::TryFrom<&tvm_rt::ArgValue<'a>> for #ref_id
{
-            type Error = anyhow::Error;
+            type Error = tvm_rt::Error;
 
             fn try_from(arg_value: &tvm_rt::ArgValue<'a>) -> Result<#ref_id,
Self::Error> {
                 use std::convert::TryInto;
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index 77dbba7..41e873f 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -17,13 +17,10 @@
  * under the License.
  */
 
+use crate::DataType;
 use thiserror::Error;
 
 #[derive(Debug, Error)]
-#[error("Cannot convert from an empty array.")]
-pub struct EmptyArrayError;
-
-#[derive(Debug, Error)]
 #[error("Handle `{name}` is null.")]
 pub struct NullHandleError {
     pub name: String,
@@ -41,5 +38,32 @@ pub struct TypeMismatchError {
 }
 
 #[derive(Debug, Error)]
-#[error("Missing NDArray shape.")]
-pub struct MissingShapeError;
+pub enum NDArrayError {
+    #[error("Missing NDArray shape.")]
+    MissingShape,
+    #[error("Cannot convert from an empty array.")]
+    EmptyArray,
+    #[error("Invalid datatype when attempting to convert ndarray.")]
+    InvalidDatatype(#[from] tvm_sys::datatype::ParseDataTypeError),
+    #[error("a shape error occurred in the Rust ndarray library")]
+    ShapeError(#[from] ndarray::ShapeError),
+    #[error("Expected type `{expected}` but found `{actual}`")]
+    DataTypeMismatch { expected: DataType, actual: DataType },
+}
+
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("{0}")]
+    Downcast(#[from] tvm_sys::errors::ValueDowncastError),
+    #[error("raw pointer passed across boundary was null")]
+    Null,
+}
+
+impl Error {
+    pub fn downcast(actual_type: String, expected_type: &'static str) -> Error {
+        Self::Downcast(tvm_sys::errors::ValueDowncastError {
+            actual_type,
+            expected_type,
+        })
+    }
+}
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index 2a5f446..17f5f6e 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -33,12 +33,14 @@ use std::{
     ptr, slice, str,
     sync::Mutex,
 };
-
+use std::convert::{TryFrom};
 use anyhow::Result;
 use lazy_static::lazy_static;
 
 pub use tvm_sys::{ffi, ArgValue, RetValue};
 
+use crate::errors::Error;
+
 use super::to_boxed_fn::ToBoxedFn;
 use super::to_function::{ToFunction, Typed};
 
@@ -180,6 +182,51 @@ impl Drop for Function {
     }
 }
 
+impl From<Function> for RetValue {
+    fn from(func: Function) -> RetValue {
+        RetValue::FuncHandle(func.handle)
+    }
+}
+
+impl TryFrom<RetValue> for Function {
+    type Error = Error;
+
+    fn try_from(ret_value: RetValue) -> Result<Function, Self::Error> {
+        match ret_value {
+            RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
+            _ => Err(Error::downcast(format!("{:?}", ret_value), "FunctionHandle"))
+        }
+    }
+}
+
+impl<'a> From<Function> for ArgValue<'a> {
+    fn from(func: Function) -> ArgValue<'a> {
+        ArgValue::FuncHandle(func.handle)
+    }
+}
+
+impl<'a> TryFrom<ArgValue<'a>> for Function {
+    type Error = Error;
+
+    fn try_from(arg_value: ArgValue<'a>) -> Result<Function, Self::Error>
{
+        match arg_value {
+            ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
+            _ => Err(Error::downcast(format!("{:?}", arg_value), "FunctionHandle")),
+        }
+    }
+}
+
+impl<'a> TryFrom<&ArgValue<'a>> for Function {
+    type Error = Error;
+
+    fn try_from(arg_value: &ArgValue<'a>) -> Result<Function, Self::Error>
{
+        match arg_value {
+            ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
+            _ => Err(Error::downcast(format!("{:?}", arg_value), "FunctionHandle")),
+        }
+    }
+}
+
 /// Registers a Rust function with an arbitrary type signature in
 /// the TVM registry.
 ///
@@ -240,8 +287,8 @@ where
 }
 
 #[macro_export]
-macro_rules! external_func {
-    (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;)
=> {
+macro_rules! external_func_impl {
+    ($name:ident , [ $($ty_param:ident),* ] , ( $($arg:ident : $ty:ty),* ), $ret_type:ty, $ext_name:literal)
=> {
         ::paste::item! {
             #[allow(non_upper_case_globals)]
             static [<global_ $name>]: ::once_cell::sync::Lazy<&'static $crate::Function>
=
@@ -251,7 +298,7 @@ macro_rules! external_func {
             });
         }
 
-        pub fn $name($($arg : $ty),*) -> Result<$ret_type, anyhow::Error> {
+        pub fn $name<$($ty_param),*>($($arg : $ty),*) -> anyhow::Result<$ret_type>
{
             let func_ref: &$crate::Function = ::paste::expr! { &*[<global_ $name>]
};
             let func_ref: Box<dyn Fn($($ty),*) -> anyhow::Result<$ret_type>>
= func_ref.to_boxed_fn();
             let res: $ret_type = func_ref($($arg),*)?;
@@ -260,6 +307,17 @@ macro_rules! external_func {
     }
 }
 
+
+#[macro_export]
+macro_rules! external_func {
+    (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;)
=> {
+        $crate::external_func_impl!($name, [], ( $($arg : $ty),* ) , $ret_type, $ext_name);
+    };
+    (fn $name:ident < $($ty_param:ident),* > ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty
as $ext_name:literal;) => {
+        $crate::external_func_impl!($name, [ $($ty_param),* ] , ( $($arg : $ty),* ) , $ret_type,
$ext_name);
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs
index 874d4fe..9b64eb6 100644
--- a/rust/tvm-rt/src/lib.rs
+++ b/rust/tvm-rt/src/lib.rs
@@ -44,8 +44,6 @@ use std::{
     str,
 };
 
-use anyhow::Error;
-
 pub use crate::{
     context::{Context, DeviceType},
     errors::*,
@@ -57,7 +55,6 @@ pub use crate::{
 pub use function::{ArgValue, RetValue};
 pub use tvm_sys::byte_array::ByteArray;
 pub use tvm_sys::datatype::DataType;
-
 use tvm_sys::ffi;
 
 // Macro to check the return call to TVM runtime shared library.
@@ -80,7 +77,7 @@ pub fn get_last_error() -> &'static str {
     }
 }
 
-pub(crate) fn set_last_error(err: &Error) {
+pub(crate) fn set_last_error<E: std::error::Error>(err: &E) {
     let c_string = CString::new(err.to_string()).unwrap();
     unsafe {
         ffi::TVMAPISetLastError(c_string.as_ptr());
diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs
index 0adae8b..f97b3a4 100644
--- a/rust/tvm-rt/src/ndarray.rs
+++ b/rust/tvm-rt/src/ndarray.rs
@@ -48,16 +48,17 @@
 //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
 
 use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
-
-use crate::errors;
-use anyhow::{bail, ensure, Result};
-use ndarray::{Array, ArrayD};
-use num_traits::Num;
 use std::convert::TryInto;
 use std::ffi::c_void;
+
+use crate::errors::NDArrayError;
+
 use tvm_sys::ffi::DLTensor;
 use tvm_sys::{ffi, ByteArray, Context, DataType};
 
+use ndarray::{Array, ArrayD};
+use num_traits::Num;
+
 /// See the [`module-level documentation`](../ndarray/index.html) for more details.
 ///
 /// Wrapper around TVM array handle.
@@ -146,13 +147,13 @@ impl NDArray {
     }
 
     /// Shows whether the underlying ndarray is contiguous in memory or not.
-    pub fn is_contiguous(&self) -> Result<bool> {
+    pub fn is_contiguous(&self) -> anyhow::Result<bool> {
         Ok(match self.strides() {
             None => true,
             Some(strides) => {
-                // errors::MissingShapeError in case shape is not determined
+                // NDArrayError::MissingShape in case shape is not determined
                 self.shape()
-                    .ok_or(errors::MissingShapeError)?
+                    .ok_or(NDArrayError::MissingShape)?
                     .iter()
                     .zip(strides)
                     .rfold(
@@ -188,16 +189,16 @@ impl NDArray {
     /// assert_eq!(ndarray.shape(), Some(&mut shape[..]));
     /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
     /// ```
-    pub fn to_vec<T>(&self) -> Result<Vec<T>> {
-        ensure!(self.shape().is_some(), errors::EmptyArrayError);
+    pub fn to_vec<T>(&self) -> Result<Vec<T>, NDArrayError> {
+        if self.shape().is_none() { return Err(NDArrayError::EmptyArray); } // no shape => nothing to copy
         let earr = NDArray::empty(
-            self.shape().ok_or(errors::MissingShapeError)?,
+            self.shape().ok_or(NDArrayError::MissingShape)?,
             Context::cpu(0),
             self.dtype(),
         );
         let target = self.copy_to_ndarray(earr)?;
         let arr = target.as_dltensor();
-        let sz = self.size().ok_or(errors::MissingShapeError)?;
+        let sz = self.size().ok_or(NDArrayError::MissingShape)?;
         let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
         unsafe {
             v.as_mut_ptr()
@@ -208,7 +209,7 @@ impl NDArray {
     }
 
     /// Converts the NDArray to [`ByteArray`].
-    pub fn to_bytearray(&self) -> Result<ByteArray> {
+    pub fn to_bytearray(&self) -> Result<ByteArray, NDArrayError> {
         let v = self.to_vec::<u8>()?;
         Ok(ByteArray::from(v))
     }
@@ -238,16 +239,15 @@ impl NDArray {
     }
 
     /// Copies the NDArray to another target NDArray.
-    pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> {
+    pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, NDArrayError>
{
         if self.dtype() != target.dtype() {
-            bail!(
-                "{}",
-                errors::TypeMismatchError {
-                    expected: self.dtype().to_string(),
-                    actual: target.dtype().to_string(),
-                }
-            );
+            return Err(
+                NDArrayError::DataTypeMismatch {
+                    expected: self.dtype(),
+                    actual: target.dtype()
+                });
         }
+
         check_call!(ffi::TVMArrayCopyFromTo(
             self.as_raw_dltensor(),
             target.as_raw_dltensor(),
@@ -257,9 +257,9 @@ impl NDArray {
     }
 
     /// Copies the NDArray to a target context.
-    pub fn copy_to_ctx(&self, target: &Context) -> Result<NDArray> {
+    pub fn copy_to_ctx(&self, target: &Context) -> Result<NDArray, NDArrayError>
{
         let tmp = NDArray::empty(
-            self.shape().ok_or(errors::MissingShapeError)?,
+            self.shape().ok_or(NDArrayError::MissingShape)?,
             *target,
             self.dtype(),
         );
@@ -272,7 +272,7 @@ impl NDArray {
         rnd: &ArrayD<T>,
         ctx: Context,
         dtype: DataType,
-    ) -> Result<Self> {
+    ) -> Result<Self, NDArrayError> {
         let shape = rnd.shape().to_vec();
         let mut nd = NDArray::empty(&shape, ctx, dtype);
         let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
@@ -304,24 +304,26 @@ impl NDArray {
 macro_rules! impl_from_ndarray_rustndarray {
     ($type:ty, $type_name:tt) => {
         impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> {
-            type Error = anyhow::Error;
-            fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> {
-                ensure!(nd.shape().is_some(), errors::MissingShapeError);
+            type Error = NDArrayError;
+
+            fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error>
{
+                if nd.shape().is_none() { return Err(NDArrayError::MissingShape); }
                 assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch");
                 Ok(Array::from_shape_vec(
-                    &*nd.shape().ok_or(errors::MissingShapeError)?,
+                    &*nd.shape().ok_or(NDArrayError::MissingShape)?,
                     nd.to_vec::<$type>()?,
                 )?)
             }
         }
 
         impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> {
-            type Error = anyhow::Error;
-            fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> {
-                ensure!(nd.shape().is_some(), errors::MissingShapeError);
+            type Error = NDArrayError;
+
+            fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error>
{
+                if nd.shape().is_none() { return Err(NDArrayError::MissingShape); }
                 assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch");
                 Ok(Array::from_shape_vec(
-                    &*nd.shape().ok_or(errors::MissingShapeError)?,
+                    &*nd.shape().ok_or(NDArrayError::MissingShape)?,
                     nd.to_vec::<$type>()?,
                 )?)
             }
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index 2ff9a1f..32da18e 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -1,7 +1,10 @@
-use crate::external_func;
 use std::convert::TryFrom;
 use std::convert::TryInto;
 use std::ffi::CString;
+
+use crate::external_func;
+use crate::errors::Error;
+
 use tvm_sys::{ArgValue, RetValue};
 
 mod object_ptr;
@@ -27,14 +30,8 @@ impl ToObjectRef for ObjectRef {
     }
 }
 
-// impl<T: ToObjectRef> ToObjectRef for &T {
-//     fn to_object_ref(&self) -> ObjectRef {
-//         (*self).to_object_ref()
-//     }
-// }
-
 impl TryFrom<RetValue> for ObjectRef {
-    type Error = anyhow::Error;
+    type Error = Error;
 
     fn try_from(ret_val: RetValue) -> Result<ObjectRef, Self::Error> {
         let optr = ret_val.try_into()?;
@@ -54,7 +51,7 @@ impl From<ObjectRef> for RetValue {
 }
 
 impl<'a> std::convert::TryFrom<ArgValue<'a>> for ObjectRef {
-    type Error = anyhow::Error;
+    type Error = Error;
 
     fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectRef, Self::Error>
{
         let optr = arg_value.try_into()?;
@@ -63,7 +60,7 @@ impl<'a> std::convert::TryFrom<ArgValue<'a>> for ObjectRef
{
 }
 
 impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef {
-    type Error = anyhow::Error;
+    type Error = Error;
 
     fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectRef, Self::Error>
{
         // TODO(@jroesch): remove the clone
diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs
index c716c05..8e91878 100644
--- a/rust/tvm-rt/src/object/object_ptr.rs
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -1,10 +1,12 @@
-use anyhow::Context;
 use std::convert::TryFrom;
 use std::ffi::CString;
 use std::ptr::NonNull;
+
 use tvm_sys::ffi::{self, /* TVMObjectFree, */ TVMObjectRetain, TVMObjectTypeKey2Index};
 use tvm_sys::{ArgValue, RetValue};
 
+use crate::errors::Error;
+
 type Deleter<T> = unsafe extern "C" fn(object: *mut T) -> ();
 
 #[derive(Debug)]
@@ -27,6 +29,7 @@ fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool
{
         parent_type_index,
         &mut is_derived
     ));
+
     if is_derived == 0 {
         false
     } else {
@@ -96,7 +99,6 @@ pub struct ObjectPtr<T> {
 
 impl ObjectPtr<Object> {
     fn from_raw(object_ptr: *mut Object) -> Option<ObjectPtr<Object>> {
-        println!("{:?}", object_ptr);
         let non_null = NonNull::new(object_ptr);
         non_null.map(|ptr| ObjectPtr { ptr })
     }
@@ -144,7 +146,7 @@ impl<T: IsObject> ObjectPtr<T> {
         }
     }
 
-    pub fn downcast<U: IsObject>(&self) -> anyhow::Result<ObjectPtr<U>>
{
+    pub fn downcast<U: IsObject>(&self) -> Result<ObjectPtr<U>, Error>
{
         let child_index = Object::get_type_index::<U>();
         let object_index = self.as_object().type_index;
 
@@ -160,7 +162,7 @@ impl<T: IsObject> ObjectPtr<T> {
                 ptr: self.ptr.cast(),
             })
         } else {
-            Err(anyhow::anyhow!("failed to downcast to object subtype"))
+            Err(Error::downcast(format!("object with type index {}", object_index), U::TYPE_KEY))
         }
     }
 }
@@ -183,16 +185,16 @@ impl<'a, T: IsObject> From<ObjectPtr<T>> for RetValue
{
 }
 
 impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
-    type Error = anyhow::Error;
+    type Error = Error;
 
     fn try_from(ret_value: RetValue) -> Result<ObjectPtr<T>, Self::Error>
{
         match ret_value {
             RetValue::ObjectHandle(handle) => {
                 let handle: *mut Object = unsafe { std::mem::transmute(handle) };
-                let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?;
+                let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
                 optr.downcast()
             }
-            _ => Err(anyhow::anyhow!("unable to convert the result to an Object")),
+            _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle"))
         }
     }
 }
@@ -207,29 +209,31 @@ impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a>
{
 }
 
 impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
-    type Error = anyhow::Error;
+    type Error = Error;
+
     fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error>
{
         match arg_value {
             ArgValue::ObjectHandle(handle) => {
                 let handle = unsafe { std::mem::transmute(handle) };
-                let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?;
+                let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
                 optr.downcast()
             }
-            _ => Err(anyhow::anyhow!("unable to convert the result to an Object")),
+            _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")),
         }
     }
 }
 
 impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr<T>
{
-    type Error = anyhow::Error;
+    type Error = Error;
+
     fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error>
{
         match arg_value {
             ArgValue::ObjectHandle(handle) => {
                 let handle = unsafe { std::mem::transmute(handle) };
-                let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?;
+                let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
                 optr.downcast()
             }
-            _ => Err(anyhow::anyhow!("unable to convert the result to an Object")),
+            _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")),
         }
     }
 }
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
index 504ce3e..dac37c8 100644
--- a/rust/tvm-rt/src/to_function.rs
+++ b/rust/tvm-rt/src/to_function.rs
@@ -25,19 +25,18 @@
 //!
 //! See the tests and examples repository for more examples.
 
+use std::convert::{TryFrom, TryInto};
 use std::{
     mem::MaybeUninit,
     os::raw::{c_int, c_void},
     ptr, slice,
 };
 
-use anyhow::Result;
+use crate::errors::Error;
+use super::Function;
 
 pub use tvm_sys::{ffi, ArgValue, RetValue};
 
-use super::Function;
-use std::convert::{TryFrom, TryInto};
-
 /// A trait representing whether the function arguments
 /// and return type can be assigned to a TVM packed function.
 ///
@@ -47,7 +46,7 @@ use std::convert::{TryFrom, TryInto};
 ///
 /// And the implementation of it to `ToFunction`.
 pub trait Typed<I, O> {
-    fn args(i: &[ArgValue<'static>]) -> anyhow::Result<I>;
+    fn args(i: &[ArgValue<'static>]) -> Result<I, Error>;
     fn ret(o: O) -> RetValue;
 }
 
@@ -55,7 +54,7 @@ impl<'a, F> Typed<&'a [ArgValue<'static>], anyhow::Result<RetValue>>
for F
 where
     F: Fn(&'a [ArgValue]) -> anyhow::Result<RetValue>,
 {
-    fn args(args: &[ArgValue<'static>]) -> anyhow::Result<&'a [ArgValue<'static>]>
{
+    fn args(args: &[ArgValue<'static>]) -> Result<&'a [ArgValue<'static>],
Error> {
         // this is BAD but just hacking for time being
         Ok(unsafe { std::mem::transmute(args) })
     }
@@ -69,7 +68,7 @@ impl<F, O: Into<RetValue>> Typed<(), O> for F
 where
     F: Fn() -> O,
 {
-    fn args(_args: &[ArgValue<'static>]) -> anyhow::Result<()> {
+    fn args(_args: &[ArgValue<'static>]) -> Result<(), Error>
{
         debug_assert!(_args.len() == 0);
         Ok(())
     }
@@ -79,13 +78,13 @@ where
     }
 }
 
-impl<F, A, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A,), O>
for F
+impl<F, A, O: Into<RetValue>, E> Typed<(A,), O> for F
 where
     F: Fn(A) -> O,
-    E: std::error::Error + Send + Sync + 'static,
+    Error: From<E>,
     A: TryFrom<ArgValue<'static>, Error = E>,
 {
-    fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A,)> {
+    fn args(args: &[ArgValue<'static>]) -> Result<(A,), Error> {
         debug_assert!(args.len() == 1);
         let a: A = args[0].clone().try_into()?;
         Ok((a,))
@@ -96,14 +95,14 @@ where
     }
 }
 
-impl<F, A, B, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A, B),
O> for F
+impl<F, A, B, O: Into<RetValue>, E> Typed<(A, B), O> for F
 where
     F: Fn(A, B) -> O,
-    E: std::error::Error + Send + Sync + 'static,
+    Error: From<E>,
     A: TryFrom<ArgValue<'static>, Error = E>,
     B: TryFrom<ArgValue<'static>, Error = E>,
 {
-    fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A, B)> {
+    fn args(args: &[ArgValue<'static>]) -> Result<(A, B), Error> {
         debug_assert!(args.len() == 2);
         let a: A = args[0].clone().try_into()?;
         let b: B = args[1].clone().try_into()?;
@@ -115,15 +114,15 @@ where
     }
 }
 
-impl<F, A, B, C, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A,
B, C), O> for F
+impl<F, A, B, C, O: Into<RetValue>, E> Typed<(A, B, C), O> for F
 where
     F: Fn(A, B, C) -> O,
-    E: std::error::Error + Send + Sync + 'static,
+    Error: From<E>,
     A: TryFrom<ArgValue<'static>, Error = E>,
     B: TryFrom<ArgValue<'static>, Error = E>,
     C: TryFrom<ArgValue<'static>, Error = E>,
 {
-    fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A, B, C)> {
+    fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C), Error> {
         debug_assert!(args.len() == 3);
         let a: A = args[0].clone().try_into()?;
         let b: B = args[1].clone().try_into()?;
@@ -141,7 +140,7 @@ pub trait ToFunction<I, O>: Sized {
 
     fn into_raw(self) -> *mut Self::Handle;
 
-    fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> anyhow::Result<RetValue>
+    fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue,
Error>
     where
         Self: Typed<I, O>;
 
@@ -280,7 +279,7 @@ where
         Box::into_raw(ptr)
     }
 
-    fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue>
+    fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue,
Error>
     where
         F: Typed<(), O>,
     {
@@ -303,7 +302,7 @@ macro_rules! to_function_instance {
                 Box::into_raw(ptr)
             }
 
-            fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) ->
Result<RetValue> where F: Typed<($($param,)+), O> {
+            fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) ->
Result<RetValue, Error> where F: Typed<($($param,)+), O> {
                 // Ideally we shouldn't need to clone, probably doesn't really matter.
                 let args = F::args(args)?;
                 let out = unsafe {
diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs
index a9355e0..d9436b1 100644
--- a/rust/tvm-rt/src/value.rs
+++ b/rust/tvm-rt/src/value.rs
@@ -24,10 +24,10 @@
 use std::convert::TryFrom;
 // use std::ffi::c_void;
 
-use crate::{ArgValue, Function, Module, NDArray, RetValue};
+use crate::{ArgValue, Module, NDArray, RetValue};
 use tvm_sys::{
     errors::ValueDowncastError,
-    ffi::{TVMFunctionHandle, TVMModuleHandle},
+    ffi::{TVMModuleHandle},
     try_downcast,
 };
 
@@ -74,7 +74,6 @@ macro_rules! impl_handle_val {
     };
 }
 
-impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new);
 impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
 
 impl<'a> From<&'a NDArray> for ArgValue<'a> {
diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs
index f371497..a426474 100644
--- a/rust/tvm/src/ir/array.rs
+++ b/rust/tvm/src/ir/array.rs
@@ -1,46 +1,55 @@
-use crate::runtime::function::Builder;
-use crate::runtime::object::{ObjectRef, ToObjectRef};
-use std::convert::{TryFrom, TryInto};
+use std::convert::{TryFrom};
 use std::marker::PhantomData;
-use tvm_sys::TVMRetValue;
+
+use crate::runtime::object::{ObjectRef, ToObjectRef};
+
+use tvm_rt::RetValue;
+use tvm_rt::external_func;
 
 use anyhow::Result;
 
+
 #[derive(Clone)]
 pub struct Array<T: ToObjectRef> {
     object: ObjectRef,
     _data: PhantomData<T>,
 }
 
+external_func! {
+    fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef as "node.ArrayGetItem";
+}
+
 impl<T: ToObjectRef> Array<T> {
     pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
-        let iter = data.iter().map(|element| element.to_object_ref());
+        unimplemented!()
+        // let iter = data.iter().map(|element| element.to_object_ref());
 
-        let array_data = Builder::default()
-            .get_function("node.Array")
-            .args(iter)
-            .invoke()?
-            .try_into()?;
+        // let array_data = Builder::default()
+        //     .get_function("node.Array")
+        //     .args(iter)
+        //     .invoke()?
+        //     .try_into()?;
 
-        Ok(Array {
-            object: array_data,
-            _data: PhantomData,
-        })
+        // Ok(Array {
+        //     object: array_data,
+        //     _data: PhantomData,
+        // })
     }
 
     pub fn get(&self, index: isize) -> Result<T>
     where
-        T: TryFrom<TVMRetValue, Error = anyhow::Error>,
+        T: TryFrom<RetValue>, <T as TryFrom<RetValue>>::Error: Into<anyhow::Error>,
     {
-        // TODO(@jroesch): why do we used a signed index here?
-        let element: T = Builder::default()
-            .get_function("node.ArrayGetItem")
-            .arg(self.object.clone())
-            .arg(index)
-            .invoke()?
-            .try_into()?;
+        unimplemented!()
+        // // TODO(@jroesch): why do we used a signed index here?
+        // let element: T = Builder::default()
+        //     .get_function("node.ArrayGetItem")
+        //     .arg(self.object.clone())
+        //     .arg(index)
+        //     .invoke()?
+        //     .try_into()?;
 
-        Ok(element)
+        // Ok(element)
     }
 }
 // mod array_api {
diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
index b7cf796..9315f7c 100644
--- a/rust/tvm/src/lib.rs
+++ b/rust/tvm/src/lib.rs
@@ -31,21 +31,13 @@
 //! Checkout the `examples` repository for more details.
 
 pub use crate::{
-    context::{TVMContext, TVMDeviceType},
     errors::*,
     function::Function,
     module::Module,
     ndarray::NDArray,
 };
 
-// TODO: refactor
-pub use tvm_sys::{
-    errors as common_errors,
-    ffi::{self, DLDataType, TVMByteArray},
-    packed_func::{TVMArgValue, TVMRetValue},
-};
-
-pub type DataType = DLDataType;
+pub use tvm_rt::{Context, DeviceType, DataType};
 
 pub use tvm_rt::context;
 pub use tvm_rt::errors;
diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs
index 3657d3b..a89ab87 100644
--- a/rust/tvm/src/transform.rs
+++ b/rust/tvm/src/transform.rs
@@ -37,5 +37,5 @@ impl PassInfo {
 }
 
 external_func! {
-    fn create_func_pass(func: &Function, pass_info: PassInfo) -> Pass as "relay._transform.MakeFunctionPass";
+    fn create_func_pass(func: Function, pass_info: PassInfo) -> Pass as "relay._transform.MakeFunctionPass";
 }


Mime
View raw message