tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jroe...@apache.org
Subject [incubator-tvm] 02/04: Convert external macro to procmacro
Date Mon, 08 Jun 2020 21:01:24 GMT
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch rust-tvm
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 62b6d1bac749e53d40fd999955b496feec486079
Author: Jared Roesch <jroesch@octoml.ai>
AuthorDate: Mon Jun 8 13:42:32 2020 -0700

    Convert external macro to procmacro
---
 rust/macros/src/external.rs    | 163 +++++++++++++++++++++++++++++++++++++++++
 rust/macros/src/lib.rs         |   1 +
 rust/macros/src/object.rs      |  58 +++++++--------
 rust/macros/src/util.rs        |  11 +++
 rust/tvm-rt/Cargo.toml         |   1 -
 rust/tvm-rt/src/context.rs     |   5 +-
 rust/tvm-rt/src/errors.rs      |   2 +
 rust/tvm-rt/src/function.rs    |  11 +--
 rust/tvm-rt/src/lib.rs         |   2 +
 rust/tvm-rt/src/module.rs      |  29 ++++----
 rust/tvm-rt/src/object/mod.rs  |  14 ++--
 rust/tvm-rt/src/string.rs      |   5 +-
 rust/tvm-rt/src/to_boxed_fn.rs |  36 ++++-----
 13 files changed, 255 insertions(+), 83 deletions(-)

diff --git a/rust/macros/src/external.rs b/rust/macros/src/external.rs
new file mode 100644
index 0000000..989cc6a
--- /dev/null
+++ b/rust/macros/src/external.rs
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+use proc_macro2::Span;
+use quote::quote;
+use syn::parse::{Parse, ParseStream, Result};
+use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+
+use crate::util::get_tvm_rt_crate;
+
+/// One declaration inside `external!`: a `#[name("...")]` attribute on a bodiless `fn`.
+struct External {
+    /// Name the function is looked up under via `Function::get`.
+    tvm_name: String,
+    /// Name of the generated Rust wrapper function.
+    ident: Ident,
+    /// Generic parameters; only type parameters are accepted.
+    generics: Generics,
+    /// Declared arguments, each of the form `ident: Type`.
+    inputs: Vec<FnArg>,
+    /// Declared return type; the wrapper returns it inside a `Result`.
+    ret_type: ReturnType,
+}
+
+impl Parse for External {
+    fn parse(input: ParseStream) -> Result<Self> {
+        let method: TraitItemMethod = input.parse()?;
+        assert_eq!(method.attrs.len(), 1, "expected exactly one #[name(\"...\")] attribute");
+        let sig = method.sig;
+        let tvm_name = match method.attrs[0].parse_meta()? {
+            Meta::List(meta_list) => {
+                let name = meta_list.path.get_ident().expect("attribute path must be an identifier");
+                assert_eq!(name.to_string(), "name".to_string());
+                match meta_list.nested.first() {
+                    Some(NestedMeta::Lit(Lit::Str(lit))) => lit.value(),
+                    _ => panic!("expected a string literal inside #[name(...)]"),
+                }
+            }
+            _ => panic!("expected an attribute of the form #[name(\"...\")]"),
+        };
+        assert_eq!(method.default, None, "external functions must not have a body");
+        assert!(method.semi_token.is_some());
+
+        Ok(External {
+            tvm_name,
+            ident: sig.ident,
+            generics: sig.generics,
+            inputs: sig.inputs.iter().cloned().collect(),
+            ret_type: sig.output,
+        })
+    }
+}
+
+/// Full body of an `external!` invocation.
+struct ExternalInput {
+    /// Declarations in source order; may be empty.
+    externs: Vec<External>,
+}
+
+impl Parse for ExternalInput {
+    fn parse(input: ParseStream) -> Result<Self> {
+        let mut externs: Vec<External> = Vec::new();
+
+        while !input.is_empty() {
+            externs.push(input.parse()?);
+        }
+        Ok(ExternalInput { externs })
+    }
+}
+
+/// Expands `external!`. For each declaration it emits a lazily initialised static holding
+/// the registered `Function` and a typed wrapper `fn` that invokes it through `to_boxed_fn`.
+///
+/// Malformed declarations make the macro panic with a message describing the problem.
+pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let ext_input = syn::parse_macro_input!(input as ExternalInput);
+
+    let tvm_rt_crate = get_tvm_rt_crate();
+    let err_type = quote! { #tvm_rt_crate::Error };
+
+    let mut items = Vec::new();
+
+    for external in &ext_input.externs {
+        let name = &external.ident;
+        let global_name = Ident::new(&format!("global_{}", external.ident), Span::call_site());
+        let ext_name = &external.tvm_name;
+        // Built here rather than with `concat!` so the message is spaced and unquoted.
+        let load_err = format!("unable to load external function {} from TVM registry.", ext_name);
+
+        let ty_params: Vec<syn::TypeParam> = external
+            .generics
+            .params
+            .iter()
+            .map(|ty_param| match ty_param {
+                syn::GenericParam::Type(param) => param.clone(),
+                _ => panic!("external functions only support type parameters"),
+            })
+            .collect();
+
+        let (args, tys): (Vec<Ident>, Vec<Type>) = external
+            .inputs
+            .iter()
+            .map(|arg| match arg {
+                FnArg::Typed(pat_type) => match &*pat_type.pat {
+                    Pat::Ident(pat_ident) => (pat_ident.ident.clone(), (*pat_type.ty).clone()),
+                    _ => panic!("external function arguments must be plain identifiers"),
+                },
+                FnArg::Receiver(_) => panic!("external functions cannot take `self`"),
+            })
+            .unzip();
+
+        let ret_type = match &external.ret_type {
+            ReturnType::Type(_, rtype) => (**rtype).clone(),
+            ReturnType::Default => panic!("external functions must declare a return type"),
+        };
+
+        let global = quote! {
+            #[allow(non_upper_case_globals)]
+            static #global_name: ::once_cell::sync::Lazy<&'static #tvm_rt_crate::Function> =
+                ::once_cell::sync::Lazy::new(|| {
+                    #tvm_rt_crate::Function::get(#ext_name).expect(#load_err)
+                });
+        };
+
+        items.push(global);
+
+        // `Result` and `Box` are fully qualified so the expansion still compiles where the
+        // caller has shadowed `Result` with a one-parameter alias.
+        let wrapper = quote! {
+            pub fn #name<#(#ty_params),*>(#(#args: #tys),*)
+                -> ::std::result::Result<#ret_type, #err_type>
+            {
+                let func_ref: &#tvm_rt_crate::Function = &#global_name;
+                let func_ref: ::std::boxed::Box<
+                    dyn Fn(#(#tys),*) -> ::std::result::Result<#ret_type, #err_type>
+                > = func_ref.to_boxed_fn();
+                let res: #ret_type = func_ref(#(#args),*)?;
+                Ok(res)
+            }
+        };
+
+        items.push(wrapper);
+    }
+
+    proc_macro::TokenStream::from(quote! {
+        #(#items)*
+    })
+}
diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs
index d0ac1ca..603e1ce 100644
--- a/rust/macros/src/lib.rs
+++ b/rust/macros/src/lib.rs
@@ -22,6 +22,7 @@ use proc_macro::TokenStream;
 mod external;
 mod import_module;
 mod object;
+mod util;
 
 #[proc_macro]
 pub fn import_module(input: TokenStream) -> TokenStream {
diff --git a/rust/macros/src/object.rs b/rust/macros/src/object.rs
index 670d326..bee22c3 100644
--- a/rust/macros/src/object.rs
+++ b/rust/macros/src/object.rs
@@ -23,7 +23,10 @@ use quote::quote;
 use syn::DeriveInput;
 use syn::Ident;
 
+use crate::util::get_tvm_rt_crate;
+
 pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
+    let tvm_rt_crate = get_tvm_rt_crate();
     let derive_input = syn::parse_macro_input!(input as DeriveInput);
     let payload_id = derive_input.ident;
 
@@ -63,7 +66,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
     let base = base.expect("should be present");
 
     let expanded = quote! {
-        unsafe impl tvm_rt::object::IsObject for #payload_id {
+        unsafe impl #tvm_rt_crate::object::IsObject for #payload_id {
             const TYPE_KEY: &'static str = #type_key;
 
             fn as_object<'s>(&'s self) -> &'s Object {
@@ -72,9 +75,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
         }
 
         #[derive(Clone)]
-        pub struct #ref_id(Option<tvm_rt::object::ObjectPtr<#payload_id>>);
+        pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>);
 
-        impl tvm_rt::object::ToObjectRef for #ref_id {
+        impl #tvm_rt_crate::object::ToObjectRef for #ref_id {
             fn to_object_ref(&self) -> ObjectRef {
                 ObjectRef(self.0.as_ref().map(|o| o.upcast()))
             }
@@ -88,25 +91,25 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
             }
         }
 
-        impl std::convert::TryFrom<tvm_rt::RetValue> for #ref_id {
-            type Error = tvm_rt::Error;
+        impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id {
+            type Error = #tvm_rt_crate::Error;
 
-            fn try_from(ret_val: tvm_rt::RetValue) -> Result<#ref_id, Self::Error>
{
+            fn try_from(ret_val: #tvm_rt_crate::RetValue) -> Result<#ref_id, Self::Error>
{
                 use std::convert::TryInto;
                 let oref: ObjectRef = ret_val.try_into()?;
-                let ptr = oref.0.ok_or(tvm_rt::Error::Null)?;
+                let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?;
                 let ptr = ptr.downcast::<#payload_id>()?;
                 Ok(#ref_id(Some(ptr)))
             }
         }
 
-        impl<'a> From<#ref_id> for tvm_rt::ArgValue<'a> {
-            fn from(object_ref: #ref_id) -> tvm_rt::ArgValue<'a> {
+        impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> {
+            fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> {
                 use std::ffi::c_void;
                 let object_ptr = &object_ref.0;
                 match object_ptr {
                     None => {
-                        tvm_rt::ArgValue::
+                        #tvm_rt_crate::ArgValue::
                             ObjectHandle(std::ptr::null::<c_void>() as *mut c_void)
                     }
                     Some(value) => value.clone().into()
@@ -114,40 +117,40 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream
{
             }
         }
 
-        impl<'a> From<&#ref_id> for tvm_rt::ArgValue<'a> {
-            fn from(object_ref: &#ref_id) -> tvm_rt::ArgValue<'a> {
+        impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> {
+            fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> {
                 let oref: #ref_id = object_ref.clone();
-                tvm_rt::ArgValue::<'a>::from(oref)
+                #tvm_rt_crate::ArgValue::<'a>::from(oref)
             }
         }
 
-        impl<'a> std::convert::TryFrom<tvm_rt::ArgValue<'a>> for #ref_id
{
-            type Error = tvm_rt::Error;
+        impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for
#ref_id {
+            type Error = #tvm_rt_crate::Error;
 
-            fn try_from(arg_value: tvm_rt::ArgValue<'a>) -> Result<#ref_id, Self::Error>
{
+            fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id,
Self::Error> {
                 use std::convert::TryInto;
                 let optr = arg_value.try_into()?;
                 Ok(#ref_id(Some(optr)))
             }
         }
 
-        impl<'a> std::convert::TryFrom<&tvm_rt::ArgValue<'a>> for #ref_id
{
-            type Error = tvm_rt::Error;
+        impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>>
for #ref_id {
+            type Error = #tvm_rt_crate::Error;
 
-            fn try_from(arg_value: &tvm_rt::ArgValue<'a>) -> Result<#ref_id,
Self::Error> {
+            fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id,
Self::Error> {
                 use std::convert::TryInto;
                 let optr = arg_value.try_into()?;
                 Ok(#ref_id(Some(optr)))
             }
         }
 
-        impl From<#ref_id> for tvm_rt::RetValue {
-            fn from(object_ref: #ref_id) -> tvm_rt::RetValue {
+        impl From<#ref_id> for #tvm_rt_crate::RetValue {
+            fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue {
                 use std::ffi::c_void;
                 let object_ptr = &object_ref.0;
                 match object_ptr {
                     None => {
-                        tvm_rt::RetValue::ObjectHandle(std::ptr::null::<c_void>() as
*mut c_void)
+                        #tvm_rt_crate::RetValue::ObjectHandle(std::ptr::null::<c_void>()
as *mut c_void)
                     }
                     Some(value) => value.clone().into()
                 }
@@ -158,14 +161,3 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
 
     TokenStream::from(expanded)
 }
-
-//  impl TryFrom<RetValue> for Var {
-//    type Error = anyhow::Error;
-
-//    fn try_from(ret_val: RetValue) -> Result<Var, Self::Error> {
-//       let oref: ObjectRef = ret_val.try_into()?;
-//       let var_ptr = oref.0.ok_or(anyhow!("null ptr"))?;
-//       let var_ptr = var_ptr.downcast::<VarNode>()?;
-//       Ok(Var(Some(var_ptr)))
-//    }
-// }
diff --git a/rust/macros/src/util.rs b/rust/macros/src/util.rs
new file mode 100644
index 0000000..c07c073
--- /dev/null
+++ b/rust/macros/src/util.rs
@@ -0,0 +1,11 @@
+use proc_macro2::TokenStream;
+use quote::quote;
+use std::env;
+
+/// Path to `tvm_rt` for generated code: `crate` inside tvm-rt itself, `tvm_rt` elsewhere.
+pub fn get_tvm_rt_crate() -> TokenStream {
+    match env::var("CARGO_PKG_NAME") {
+        Ok(name) if name == "tvm-rt" => quote!(crate),
+        _ => quote!(tvm_rt), // also used when CARGO_PKG_NAME is unset, e.g. outside Cargo
+    }
+}
diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml
index 417f256..b234da1 100644
--- a/rust/tvm-rt/Cargo.toml
+++ b/rust/tvm-rt/Cargo.toml
@@ -30,7 +30,6 @@ edition = "2018"
 
 [dependencies]
 thiserror = "^1.0"
-anyhow = "^1.0"
 lazy_static = "1.1"
 ndarray = "0.12"
 num-traits = "0.2"
diff --git a/rust/tvm-rt/src/context.rs b/rust/tvm-rt/src/context.rs
index ea4cffa..0c01d91 100644
--- a/rust/tvm-rt/src/context.rs
+++ b/rust/tvm-rt/src/context.rs
@@ -30,8 +30,9 @@ macro_rules! impl_device_attrs {
     };
 }
 
-external_func! {
-    fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32 as "runtime.GetDeviceAttr";
+crate::external! {
+    #[name("runtime.GetDeviceAttr")]
+    fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32;
 }
 
 impl ContextExt for Context {
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index f081258..414484d 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -60,6 +60,8 @@ pub enum Error {
     Downcast(#[from] tvm_sys::errors::ValueDowncastError),
     #[error("raw pointer passed across boundary was null")]
     Null,
+    #[error("failed to load module due to invalid path {0}")]
+    ModuleLoadPath(String),
 }
 
 impl Error {
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index b0122ff..4b34bc1 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -25,7 +25,6 @@
 //!
 //! See the tests and examples repository for more examples.
 
-use anyhow::Result;
 use lazy_static::lazy_static;
 use std::convert::TryFrom;
 use std::{
@@ -44,6 +43,8 @@ use crate::errors::Error;
 use super::to_boxed_fn::ToBoxedFn;
 use super::to_function::{ToFunction, Typed};
 
+pub type Result<T, E = Error> = std::result::Result<T, E>;
+
 lazy_static! {
     static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<String, Option<Function>>> = {
         let mut out_size = 0 as c_int;
@@ -137,7 +138,7 @@ impl Function {
     }
 
     /// Calls the function that created from `Builder`.
-    pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue> {
+    pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue, Error> {
         let num_args = arg_buf.len();
         let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) =
             arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();
@@ -263,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
 /// let ret = boxed_fn(10, 20, 30).unwrap();
 /// assert_eq!(ret, 60);
 /// ```
-pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<()>
+pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<(), Error>
 where
     F: ToFunction<I, O>,
     F: Typed<I, O>,
@@ -274,7 +275,7 @@ where
 /// Register a function with explicit control over whether to override an existing registration or not.
 ///
 /// See `register` for more details on how to use the registration API.
-pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<()>
+pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<(), Error>
 where
     F: ToFunction<I, O>,
     F: Typed<I, O>,
@@ -323,7 +324,7 @@ mod tests {
 
         function::register_override(constfn, "constfn".to_owned(), true).unwrap();
         let func = Function::get("constfn").unwrap();
-        let func = func.to_boxed_fn::<dyn Fn() -> Result<i32>>();
+        let func = func.to_boxed_fn::<dyn Fn() -> Result<i32, Error>>();
         let ret = func().unwrap();
         assert_eq!(ret, 10);
     }
diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs
index 9b64eb6..70a8efd 100644
--- a/rust/tvm-rt/src/lib.rs
+++ b/rust/tvm-rt/src/lib.rs
@@ -57,6 +57,8 @@ pub use tvm_sys::byte_array::ByteArray;
 pub use tvm_sys::datatype::DataType;
 use tvm_sys::ffi;
 
+pub use tvm_macros::external;
+
 // Macro to check the return call to TVM runtime shared library.
 #[macro_export]
 macro_rules! check_call {
diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs
index 2abccc7..b8b56f4 100644
--- a/rust/tvm-rt/src/module.rs
+++ b/rust/tvm-rt/src/module.rs
@@ -26,10 +26,10 @@ use std::{
     ptr,
 };
 
-use anyhow::{anyhow, ensure, Error};
 use tvm_sys::ffi;
 
 use crate::{errors, function::Function};
+use crate::errors::Error;
 
 const ENTRY_FUNC: &str = "__tvm_main__";
 
@@ -43,12 +43,12 @@ pub struct Module {
     entry_func: Option<Function>,
 }
 
-external_func! {
-    fn runtime_enabled(target: CString) -> i32 as "runtime.RuntimeEnabled";
-}
+crate::external! {
+    #[name("runtime.RuntimeEnabled")]
+    fn runtime_enabled(target: CString) -> i32;
 
-external_func! {
-    fn load_from_file(file_name: CString, format: CString) -> Module as "runtime.ModuleLoadFromFile";
+    #[name("runtime.ModuleLoadFromFile")]
+    fn load_from_file(file_name: CString, format: CString) -> Module;
 }
 
 impl Module {
@@ -76,12 +76,13 @@ impl Module {
             query_import as c_int,
             &mut fhandle as *mut _
         ));
-        ensure!(
-            !fhandle.is_null(),
-            errors::NullHandleError {
+        // A null handle cannot be wrapped in a `Function`; report it as `NullHandleError`.
+        if fhandle.is_null() {
+            return Err(errors::NullHandleError {
                 name: name.into_string()?.to_string()
-            }
-        );
+            }.into());
+        }
+
         Ok(Function::new(fhandle))
     }
 
@@ -97,13 +98,15 @@ impl Module {
                 .extension()
                 .unwrap_or_else(|| std::ffi::OsStr::new(""))
                 .to_str()
-                .ok_or_else(|| anyhow!("Bad module load path: `{}`.", path.as_ref().display()))?,
+                .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?,
         )?;
+
         let cpath = CString::new(
             path.as_ref()
                 .to_str()
-                .ok_or_else(|| anyhow!("Bad module load path: `{}`.", path.as_ref().display()))?,
+                .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?,
         )?;
+
         let module = load_from_file(cpath, ext)?;
         Ok(module)
     }
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index 9dcf836..6ddfe16 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -3,7 +3,7 @@ use std::convert::TryInto;
 use std::ffi::CString;
 
 use crate::errors::Error;
-use crate::external_func;
+use crate::external;
 
 use tvm_sys::{ArgValue, RetValue};
 
@@ -87,10 +87,12 @@ impl<'a> From<&ObjectRef> for ArgValue<'a> {
     }
 }
 
-external_func! {
-    fn debug_print(object: ObjectRef) -> CString as "ir.DebugPrint";
+external! {
+    #[name("ir.DebugPrint")]
+    fn debug_print(object: ObjectRef) -> CString;
 }
 
-external_func! {
-    fn as_text(object: ObjectRef) -> CString as "ir.TextPrinter";
-}
+// external! {
+//     #[name("ir.TextPrinter")]
+//     fn as_text(object: ObjectRef) -> CString;
+// }
diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs
index ac80625..ce1cee6 100644
--- a/rust/tvm-rt/src/string.rs
+++ b/rust/tvm-rt/src/string.rs
@@ -2,7 +2,8 @@ use std::ffi::{CString, NulError};
 use std::os::raw::c_char;
 
 use super::{Object, ObjectPtr, ObjectRef};
-use crate as tvm_rt;
+use super::errors::Error;
+
 use tvm_macros::Object;
 
 #[repr(C)]
@@ -43,7 +44,7 @@ impl String {
         }
     }
 
-    pub fn to_string(&self) -> anyhow::Result<std::string::String> {
+    pub fn to_string(&self) -> Result<std::string::String, Error> {
         let string = self.to_cstring()?.into_string()?;
         Ok(string)
     }
diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs
index d6ea96d..d2dde67 100644
--- a/rust/tvm-rt/src/to_boxed_fn.rs
+++ b/rust/tvm-rt/src/to_boxed_fn.rs
@@ -25,24 +25,23 @@
 //!
 //! See the tests and examples repository for more examples.
 
-use anyhow::Result;
-
 pub use tvm_sys::{ffi, ArgValue, RetValue};
 
-use crate::Module;
+use crate::{Module, errors};
 
 use super::function::Function;
 
+type Result<T> = std::result::Result<T, errors::Error>;
+
 pub trait ToBoxedFn {
     fn to_boxed_fn(func: &'static Function) -> Box<Self>;
 }
 
 use std::convert::{TryFrom, TryInto};
 
-impl<E, O> ToBoxedFn for dyn Fn() -> Result<O>
+impl<O> ToBoxedFn for dyn Fn() -> Result<O>
 where
-    E: std::error::Error + Send + Sync + 'static,
-    O: TryFrom<RetValue, Error = E>,
+    O: TryFrom<RetValue, Error = errors::Error>,
 {
     fn to_boxed_fn(func: &'static Function) -> Box<Self> {
         Box::new(move || {
@@ -54,11 +53,10 @@ where
     }
 }
 
-impl<E, A, O> ToBoxedFn for dyn Fn(A) -> Result<O>
+impl<A, O> ToBoxedFn for dyn Fn(A) -> Result<O>
 where
-    E: std::error::Error + Send + Sync + 'static,
     A: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = E>,
+    O: TryFrom<RetValue, Error = errors::Error>,
 {
     fn to_boxed_fn(func: &'static Function) -> Box<Self> {
         Box::new(move |a: A| {
@@ -71,12 +69,11 @@ where
     }
 }
 
-impl<E, A, B, O> ToBoxedFn for dyn Fn(A, B) -> Result<O>
+impl<A, B, O> ToBoxedFn for dyn Fn(A, B) -> Result<O>
 where
-    E: std::error::Error + Send + Sync + 'static,
     A: Into<ArgValue<'static>>,
     B: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = E>,
+    O: TryFrom<RetValue, Error = errors::Error>,
 {
     fn to_boxed_fn(func: &'static Function) -> Box<Self> {
         Box::new(move |a: A, b: B| {
@@ -90,13 +87,12 @@ where
     }
 }
 
-impl<E, A, B, C, O> ToBoxedFn for dyn Fn(A, B, C) -> Result<O>
+impl<A, B, C, O> ToBoxedFn for dyn Fn(A, B, C) -> Result<O>
 where
-    E: std::error::Error + Send + Sync + 'static,
     A: Into<ArgValue<'static>>,
     B: Into<ArgValue<'static>>,
     C: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = E>,
+    O: TryFrom<RetValue, Error = errors::Error>,
 {
     fn to_boxed_fn(func: &'static Function) -> Box<Self> {
         Box::new(move |a: A, b: B, c: C| {
@@ -111,14 +107,13 @@ where
     }
 }
 
-impl<E, A, B, C, D, O> ToBoxedFn for dyn Fn(A, B, C, D) -> Result<O>
+impl<A, B, C, D, O> ToBoxedFn for dyn Fn(A, B, C, D) -> Result<O>
 where
-    E: std::error::Error + Send + Sync + 'static,
     A: Into<ArgValue<'static>>,
     B: Into<ArgValue<'static>>,
     C: Into<ArgValue<'static>>,
     D: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = E>,
+    O: TryFrom<RetValue, Error = errors::Error>,
 {
     fn to_boxed_fn(func: &'static Function) -> Box<Self> {
         Box::new(move |a: A, b: B, c: C, d: D| {
@@ -183,7 +178,7 @@ impl<'a, 'm> Builder<'a, 'm> {
         self
     }
 
-    /// Sets an output for a function that requirs a mutable output to be provided.
+    /// Sets an output for a function that requires a mutable output to be provided.
     /// See the `basics` in tests for an example.
     pub fn set_output<T>(&mut self, ret: T) -> &mut Self
     where
@@ -214,8 +209,7 @@ impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm>
{
 }
 #[cfg(test)]
 mod tests {
-    use crate::function::{self, Function};
-    use anyhow::Result;
+    use crate::function::{self, Function, Result};
 
     #[test]
     fn to_boxed_fn0() {


Mime
View raw message