tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jroe...@apache.org
Subject [incubator-tvm] 03/04: Finish removing anyhow and work with new external! macro
Date Mon, 08 Jun 2020 21:01:25 GMT
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch rust-tvm
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 0c55c39477979c75f2bf3e2e9974d90fde74fa26
Author: Jared Roesch <jroesch@octoml.ai>
AuthorDate: Mon Jun 8 13:56:28 2020 -0700

    Finish removing anyhow and work with new external! macro
---
 rust/tvm-rt/src/context.rs     | 12 ++++++++----
 rust/tvm-rt/src/errors.rs      | 14 ++++++++------
 rust/tvm-rt/src/function.rs    | 12 ++++++------
 rust/tvm-rt/src/module.rs      | 10 +++++-----
 rust/tvm-rt/src/ndarray.rs     |  2 +-
 rust/tvm-rt/src/to_boxed_fn.rs | 29 ++++++++++++++++-------------
 rust/tvm-rt/src/to_function.rs | 30 +++++++++++++++---------------
 7 files changed, 59 insertions(+), 50 deletions(-)

diff --git a/rust/tvm-rt/src/context.rs b/rust/tvm-rt/src/context.rs
index 0c01d91..b1bdab5 100644
--- a/rust/tvm-rt/src/context.rs
+++ b/rust/tvm-rt/src/context.rs
@@ -1,13 +1,17 @@
-pub use tvm_sys::context::*;
-use tvm_sys::ffi;
 
 use std::os::raw::c_void;
 use std::ptr;
 
+use crate::errors::Error;
+
+use tvm_sys::ffi;
+
+pub use tvm_sys::context::*;
+
 trait ContextExt {
     /// Checks whether the context exists or not.
     fn exist(&self) -> bool;
-    fn sync(&self) -> anyhow::Result<()>;
+    fn sync(&self) -> Result<(), Error>;
     fn max_threads_per_block(&self) -> isize;
     fn warp_size(&self) -> isize;
     fn max_shared_memory_per_block(&self) -> isize;
@@ -44,7 +48,7 @@ impl ContextExt for Context {
     }
 
     /// Synchronize the context stream.
-    fn sync(&self) -> anyhow::Result<()> {
+    fn sync(&self) -> Result<(), Error> {
         check_call!(ffi::TVMSynchronize(
             self.device_type as i32,
             self.device_id as i32,
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index 414484d..197c875 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -21,12 +21,6 @@ use crate::DataType;
 use thiserror::Error;
 
 #[derive(Debug, Error)]
-#[error("Handle `{name}` is null.")]
-pub struct NullHandleError {
-    pub name: String,
-}
-
-#[derive(Debug, Error)]
 #[error("Function was not set in `function::Builder`")]
 pub struct FunctionNotFoundError;
 
@@ -62,6 +56,14 @@ pub enum Error {
     Null,
     #[error("failed to load module due to invalid path {0}")]
     ModuleLoadPath(String),
+    #[error("failed to convert String into CString due to embedded nul character")]
+    ToCString(#[from] std::ffi::NulError),
+    #[error("failed to convert CString into String")]
+    FromCString(#[from] std::ffi::IntoStringError),
+    #[error("Handle `{0}` is null.")]
+    NullHandle(String),
+    #[error("{0}")]
+    NDArray(#[from] NDArrayError),
 }
 
 impl Error {
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index 4b34bc1..cca918a 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -138,7 +138,7 @@ impl Function {
     }
 
     /// Calls the function that created from `Builder`.
-    pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue, Error> {
+    pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue> {
         let num_args = arg_buf.len();
         let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) =
             arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();
@@ -192,7 +192,7 @@ impl From<Function> for RetValue {
 impl TryFrom<RetValue> for Function {
     type Error = Error;
 
-    fn try_from(ret_value: RetValue) -> Result<Function, Self::Error> {
+    fn try_from(ret_value: RetValue) -> Result<Function> {
         match ret_value {
             RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
             _ => Err(Error::downcast(
@@ -212,7 +212,7 @@ impl<'a> From<Function> for ArgValue<'a> {
 impl<'a> TryFrom<ArgValue<'a>> for Function {
     type Error = Error;
 
-    fn try_from(arg_value: ArgValue<'a>) -> Result<Function, Self::Error> {
+    fn try_from(arg_value: ArgValue<'a>) -> Result<Function> {
         match arg_value {
             ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
             _ => Err(Error::downcast(
@@ -226,7 +226,7 @@ impl<'a> TryFrom<ArgValue<'a>> for Function {
 impl<'a> TryFrom<&ArgValue<'a>> for Function {
     type Error = Error;
 
-    fn try_from(arg_value: &ArgValue<'a>) -> Result<Function, Self::Error> {
+    fn try_from(arg_value: &ArgValue<'a>) -> Result<Function> {
         match arg_value {
             ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
             _ => Err(Error::downcast(
@@ -264,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
 /// let ret = boxed_fn(10, 20, 30).unwrap();
 /// assert_eq!(ret, 60);
 /// ```
-pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<(), Error>
+pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<()>
 where
     F: ToFunction<I, O>,
     F: Typed<I, O>,
@@ -275,7 +275,7 @@ where
 /// Register a function with explicit control over whether to override an existing registration or not.
 ///
 /// See `register` for more details on how to use the registration API.
-pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<(), Error>
+pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<()>
 where
     F: ToFunction<I, O>,
     F: Typed<I, O>,
diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs
index b8b56f4..c161af5 100644
--- a/rust/tvm-rt/src/module.rs
+++ b/rust/tvm-rt/src/module.rs
@@ -78,9 +78,9 @@ impl Module {
         ));
 
         if !fhandle.is_null() {
-            return Err(errors::NullHandleError {
-                name: name.into_string()?.to_string()
-            })
+            return Err(errors::Error::NullHandle(
+                name.into_string()?.to_string()
+            ));
         }
 
         Ok(Function::new(fhandle))
@@ -98,13 +98,13 @@ impl Module {
                 .extension()
                 .unwrap_or_else(|| std::ffi::OsStr::new(""))
                 .to_str()
-                .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display()))
+                .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?
         )?;
 
         let cpath = CString::new(
             path.as_ref()
                 .to_str()
-                .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display()))
+                .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?
         )?;
 
         let module = load_from_file(cpath, ext)?;
diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs
index 593154d..9a17502 100644
--- a/rust/tvm-rt/src/ndarray.rs
+++ b/rust/tvm-rt/src/ndarray.rs
@@ -147,7 +147,7 @@ impl NDArray {
     }
 
     /// Shows whether the underlying ndarray is contiguous in memory or not.
-    pub fn is_contiguous(&self) -> anyhow::Result<bool> {
+    pub fn is_contiguous(&self) -> Result<bool, crate::errors::Error> {
         Ok(match self.strides() {
             None => true,
             Some(strides) => {
diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs
index d2dde67..12e4351 100644
--- a/rust/tvm-rt/src/to_boxed_fn.rs
+++ b/rust/tvm-rt/src/to_boxed_fn.rs
@@ -29,9 +29,7 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};
 
 use crate::{Module, errors};
 
-use super::function::Function;
-
-type Result<T> = std::result::Result<T, errors::Error>;
+use super::function::{Function, Result};
 
 pub trait ToBoxedFn {
     fn to_boxed_fn(func: &'static Function) -> Box<Self>;
@@ -39,9 +37,10 @@ pub trait ToBoxedFn {
 
 use std::convert::{TryFrom, TryInto};
 
-impl<O> ToBoxedFn for dyn Fn() -> Result<O>
+impl<E, O> ToBoxedFn for dyn Fn() -> Result<O>
 where
-    O: TryFrom<RetValue, Error = errors::Error>,
+    errors::Error: From<E>,
+    O: TryFrom<RetValue, Error = E>,
 {
     fn to_boxed_fn(func: &'static Function) -> Box<Self> {
         Box::new(move || {
@@ -53,10 +52,11 @@ where
     }
 }
 
-impl<A, O> ToBoxedFn for dyn Fn(A) -> Result<O>
+impl<E, A, O> ToBoxedFn for dyn Fn(A) -> Result<O>
 where
+    errors::Error: From<E>,
     A: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = errors::Error>,
+    O: TryFrom<RetValue, Error = E>,
 {
     fn to_boxed_fn(func: &'static Function) -> Box<Self> {
         Box::new(move |a: A| {
@@ -69,11 +69,12 @@ where
     }
 }
 
-impl<A, B, O> ToBoxedFn for dyn Fn(A, B) -> Result<O>
+impl<E, A, B, O> ToBoxedFn for dyn Fn(A, B) -> Result<O>
 where
+    errors::Error: From<E>,
     A: Into<ArgValue<'static>>,
     B: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = errors::Error>,
+    O: TryFrom<RetValue, Error = E>,
 {
     fn to_boxed_fn(func: &'static Function) -> Box<Self> {
         Box::new(move |a: A, b: B| {
@@ -87,12 +88,13 @@ where
     }
 }
 
-impl<A, B, C, O> ToBoxedFn for dyn Fn(A, B, C) -> Result<O>
+impl<E, A, B, C, O> ToBoxedFn for dyn Fn(A, B, C) -> Result<O>
 where
+    errors::Error: From<E>,
     A: Into<ArgValue<'static>>,
     B: Into<ArgValue<'static>>,
     C: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = errors::Error>,
+    O: TryFrom<RetValue, Error = E>,
 {
     fn to_boxed_fn(func: &'static Function) -> Box<Self> {
         Box::new(move |a: A, b: B, c: C| {
@@ -107,13 +109,14 @@ where
     }
 }
 
-impl<A, B, C, D, O> ToBoxedFn for dyn Fn(A, B, C, D) -> Result<O>
+impl<E, A, B, C, D, O> ToBoxedFn for dyn Fn(A, B, C, D) -> Result<O>
 where
+    errors::Error: From<E>,
     A: Into<ArgValue<'static>>,
     B: Into<ArgValue<'static>>,
     C: Into<ArgValue<'static>>,
     D: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = errors::Error>,
+    O: TryFrom<RetValue, Error = E>,
 {
     fn to_boxed_fn(func: &'static Function) -> Box<Self> {
         Box::new(move |a: A, b: B, c: C, d: D| {
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
index 0527b0c..9d8065c 100644
--- a/rust/tvm-rt/src/to_function.rs
+++ b/rust/tvm-rt/src/to_function.rs
@@ -32,7 +32,7 @@ use std::{
     ptr, slice,
 };
 
-use super::Function;
+use super::{Function, function::Result};
 use crate::errors::Error;
 
 pub use tvm_sys::{ffi, ArgValue, RetValue};
@@ -46,20 +46,20 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};
 ///
 /// And the implementation of it to `ToFunction`.
 pub trait Typed<I, O> {
-    fn args(i: &[ArgValue<'static>]) -> Result<I, Error>;
+    fn args(i: &[ArgValue<'static>]) -> Result<I>;
     fn ret(o: O) -> RetValue;
 }
 
-impl<'a, F> Typed<&'a [ArgValue<'static>], anyhow::Result<RetValue>> for F
+impl<'a, F> Typed<&'a [ArgValue<'static>], Result<RetValue>> for F
 where
-    F: Fn(&'a [ArgValue]) -> anyhow::Result<RetValue>,
+    F: Fn(&'a [ArgValue]) -> Result<RetValue>,
 {
-    fn args(args: &[ArgValue<'static>]) -> Result<&'a [ArgValue<'static>], Error> {
+    fn args(args: &[ArgValue<'static>]) -> Result<&'a [ArgValue<'static>]> {
         // this is BAD but just hacking for time being
         Ok(unsafe { std::mem::transmute(args) })
     }
 
-    fn ret(ret_value: anyhow::Result<RetValue>) -> RetValue {
+    fn ret(ret_value: Result<RetValue>) -> RetValue {
         ret_value.unwrap()
     }
 }
@@ -68,7 +68,7 @@ impl<F, O: Into<RetValue>> Typed<(), O> for F
 where
     F: Fn() -> O,
 {
-    fn args(_args: &[ArgValue<'static>]) -> anyhow::Result<(), Error> {
+    fn args(_args: &[ArgValue<'static>]) -> Result<()> {
         debug_assert!(_args.len() == 0);
         Ok(())
     }
@@ -84,7 +84,7 @@ where
     Error: From<E>,
     A: TryFrom<ArgValue<'static>, Error = E>,
 {
-    fn args(args: &[ArgValue<'static>]) -> Result<(A,), Error> {
+    fn args(args: &[ArgValue<'static>]) -> Result<(A,)> {
         debug_assert!(args.len() == 1);
         let a: A = args[0].clone().try_into()?;
         Ok((a,))
@@ -102,7 +102,7 @@ where
     A: TryFrom<ArgValue<'static>, Error = E>,
     B: TryFrom<ArgValue<'static>, Error = E>,
 {
-    fn args(args: &[ArgValue<'static>]) -> Result<(A, B), Error> {
+    fn args(args: &[ArgValue<'static>]) -> Result<(A, B)> {
         debug_assert!(args.len() == 2);
         let a: A = args[0].clone().try_into()?;
         let b: B = args[1].clone().try_into()?;
@@ -122,7 +122,7 @@ where
     B: TryFrom<ArgValue<'static>, Error = E>,
     C: TryFrom<ArgValue<'static>, Error = E>,
 {
-    fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C), Error> {
+    fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C)> {
         debug_assert!(args.len() == 3);
         let a: A = args[0].clone().try_into()?;
         let b: B = args[1].clone().try_into()?;
@@ -140,7 +140,7 @@ pub trait ToFunction<I, O>: Sized {
 
     fn into_raw(self) -> *mut Self::Handle;
 
-    fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue, Error>
+    fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue>
     where
         Self: Typed<I, O>;
 
@@ -242,7 +242,7 @@ pub trait ToFunction<I, O>: Sized {
 // }
 
 // impl Typed<&[ArgValue<'static>], ()> for RawFunction {
-//     fn args(i: &[ArgValue<'static>]) -> anyhow::Result<&[ArgValue<'static>]> {
+//     fn args(i: &[ArgValue<'static>]) -> Result<&[ArgValue<'static>]> {
 //         Ok(i)
 //     }
 
@@ -279,7 +279,7 @@ where
         Box::into_raw(ptr)
     }
 
-    fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue, Error>
+    fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue>
     where
         F: Typed<(), O>,
     {
@@ -302,7 +302,7 @@ macro_rules! to_function_instance {
                 Box::into_raw(ptr)
             }
 
-            fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue, Error> where F: Typed<($($param,)+), O> {
+            fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue> where F: Typed<($($param,)+), O> {
                 // Ideally we shouldn't need to clone, probably doesn't really matter.
                 let args = F::args(args)?;
                 let out = unsafe {
@@ -338,7 +338,7 @@ mod tests {
         f.to_function()
     }
 
-    // fn func_args(args: &[ArgValue<'static>]) -> anyhow::Result<RetValue> {
+    // fn func_args(args: &[ArgValue<'static>]) -> Result<RetValue> {
     //     Ok(10.into())
     // }
 


Mime
View raw message