tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] robo-corg commented on a change in pull request #5830: Rust Refactor Stage 4: Rewrite Rust graph runtime to use new APIs
Date Fri, 19 Jun 2020 23:05:05 GMT

robo-corg commented on a change in pull request #5830:
URL: https://github.com/apache/incubator-tvm/pull/5830#discussion_r442440869



##########
File path: rust/tvm-graph-rt/src/allocator.rs
##########
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::alloc::{self, Layout, LayoutErr};
+
+const DEFAULT_ALIGN_BYTES: usize = 4;
+
+#[derive(PartialEq, Eq)]
+pub struct Allocation {
+    layout: Layout,
+    ptr: *mut u8,
+}
+
+impl Allocation {
+    /// Allocates a chunk of memory of `size` bytes with optional alignment.
+    pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> {
+        let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
+        let layout = Layout::from_size_align(size, alignment)?;
+        let ptr = unsafe { alloc::alloc(layout) };
+        if ptr.is_null() {
+            alloc::handle_alloc_error(layout);
+        }
+        Ok(Self { ptr, layout })
+    }

Review comment:
       Should this return `std::mem::MaybeUninit<Allocation>`, or does that not matter here since it is just bytes?

##########
File path: rust/tvm-graph-rt/src/array.rs
##########
@@ -0,0 +1,400 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
+
+use failure::{ensure, Error};
+use ndarray;
+use tvm_sys::{ffi::DLTensor, Context, DataType};
+
+use crate::allocator::Allocation;
+
+/// A `Storage` is a container which holds `Tensor` data.
+#[derive(PartialEq)]
+pub enum Storage<'a> {
+    /// A `Storage` which owns its contained bytes.
+    Owned(Allocation),
+
+    /// A view of an existing `Storage`.
+    View(&'a mut [u8], usize), // ptr, align
+}
+
+impl<'a> Storage<'a> {
+    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
+        Ok(Storage::Owned(Allocation::new(size, align)?))
+    }
+
+    pub fn as_mut_ptr(&self) -> *mut u8 {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_ptr(),
+            Storage::View(slice, _) => slice.as_ptr() as *mut u8,
+        }
+    }
+
+    pub fn size(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.size(),
+            Storage::View(slice, _) => slice.len(),
+        }
+    }
+
+    pub fn align(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.align(),
+            Storage::View(_, align) => *align,
+        }
+    }
+
+    pub fn as_ptr(&self) -> *const u8 {
+        self.as_mut_ptr() as *const _
+    }
+
+    /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
+    pub fn view(&self) -> Storage<'a> {
+        match self {
+            Storage::Owned(alloc) => Storage::View(
+                unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
+                self.align(),
+            ),
+            Storage::View(slice, _) => Storage::View(
+                unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
+                self.align(),
+            ),
+        }
+    }
+
+    pub fn is_owned(&self) -> bool {
+        match self {
+            Storage::Owned(_) => true,
+            _ => false,
+        }
+    }
+
+    /// Returns an owned version of this storage via cloning.
+    pub fn to_owned(&self) -> Storage<'static> {
+        let s = Storage::new(self.size(), Some(self.align())).unwrap();
+        unsafe {
+            s.as_mut_ptr()
+                .copy_from_nonoverlapping(self.as_ptr(), self.size());
+        }
+        s
+    }
+
+    /// Returns a view of the stored data.
+    pub fn as_slice(&self) -> &[u8] {
+        match self {
+            Storage::Owned(alloc) => alloc.as_slice(),
+            Storage::View(slice, _) => &*slice,
+        }
+    }
+
+    /// Returns a mutable view of the stored data.
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_slice(),
+            Storage::View(slice, _) => slice,
+        }
+    }
+}
+
+impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
+    fn from(data: &'d [T]) -> Self {
+        let data = unsafe {
+            slice::from_raw_parts_mut(
+                data.as_ptr() as *const u8 as *mut u8,
+                data.len() * mem::size_of::<T>() as usize,
+            )
+        };
+        Storage::View(data, mem::align_of::<T>())
+    }
+}
+
+/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
+/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
+/// converted to `ndarray::Array` for non-TVM processing.
+///
+/// # Examples
+///
+/// ```
+/// extern crate ndarray;
+/// use std::convert::TryInto;
+/// use tvm_runtime::{call_packed, DLTensor, ArgValue, RetValue, Tensor};
+///
+/// let mut a_nd: ndarray::Array1<f32> = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// let mut a: Tensor = a_nd.into();
+/// let mut a_dl: DLTensor = (&mut a).into();
+///
+/// let tvm_fn = |args: &[ArgValue]| -> Result<RetValue, ()> { Ok(RetValue::default()) };
+/// call_packed!(tvm_fn, &mut a_dl);
+///
+/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
+/// let mut a_nd: ndarray::ArrayD<f32> = a.try_into().unwrap();
+/// ```
+#[derive(PartialEq)]
+pub struct Tensor<'a> {
+    /// The bytes which contain the data this `Tensor` represents.
+    pub(crate) data: Storage<'a>,
+    pub(crate) ctx: Context,
+    pub(crate) dtype: DataType,
+    pub(crate) shape: Vec<i64>,
+    // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
+    /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
+    pub(crate) strides: Option<Vec<usize>>,
+    pub(crate) byte_offset: isize,
+    /// The number of elements in the `Tensor`.
+    pub(crate) size: usize,
+}
+
+unsafe impl<'a> Send for Tensor<'a> {}
+
+impl<'a> Tensor<'a> {
+    pub fn shape(&self) -> Vec<i64> {
+        self.shape.clone()
+    }
+
+    pub fn data(&self) -> &Storage {
+        &self.data
+    }
+
+    pub fn data_mut(&mut self) -> &'a mut Storage {
+        &mut self.data
+    }

Review comment:
       I think this is unsafe, since you can now swap or replace the returned storage with one that has the wrong size and trigger UB.

##########
File path: rust/tvm-graph-rt/src/allocator.rs
##########
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::alloc::{self, Layout, LayoutErr};
+
+const DEFAULT_ALIGN_BYTES: usize = 4;
+
+#[derive(PartialEq, Eq)]
+pub struct Allocation {
+    layout: Layout,

Review comment:
       Why do we need to track alignment? I assume I will find out further on.

##########
File path: rust/tvm-graph-rt/src/graph.rs
##########
@@ -0,0 +1,495 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
+
+use failure::{ensure, format_err, Error};
+use itertools::izip;
+use nom::{
+    character::complete::{alpha1, digit1},
+    complete, count, do_parse, length_count, map, named,
+    number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8},
+    opt, tag, take, tuple,
+};
+
+use serde::{Deserialize, Serialize};
+use serde_json;
+
+use tvm_sys::ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt};
+
+use tvm_sys::{ffi::DLTensor, ArgValue, Context, DataType, DeviceType};
+
+use crate::{errors::GraphFormatError, Module, Storage, Tensor};
+
+// @see `kTVMNDArrayMagic` in `ndarray.h`
+const _NDARRAY_MAGIC: u64 = 0xDD5E_40F0_96B4_A13F;
+// @see `kTVMNDArrayListMagic` in `graph_runtime.h`
+const _NDARRAY_LIST_MAGIC: u64 = 0xF7E5_8D4F_0504_9CB7;
+
+/// A TVM computation graph.
+///
+/// # Examples
+///
+/// ```norun
+/// let graph_json = fs::read_to_string("graph.json").unwrap();
+/// let graph = Graph::try_from(&graph_json).unwrap();
+/// ```
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Graph {
+    pub nodes: Vec<Node>,
+    pub arg_nodes: Vec<usize>,
+    pub heads: Vec<Entry>,
+    pub node_row_ptr: Option<Vec<usize>>,
+    pub attrs: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Entry {
+    pub id: usize,
+    pub index: usize,
+    pub version: usize,
+}
+
+impl Graph {
+    fn entry_index(&self, entry: &Entry) -> Result<usize, GraphFormatError> {
+        self.node_row_ptr
+            .as_ref()
+            .map(|nrp| nrp[entry.id] + entry.index)
+            .ok_or_else(|| GraphFormatError::MissingField("node_row_ptr"))
+    }
+
+    /// Attempt to deserialize a JSON attribute to a type `T`.
+    fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T, GraphFormatError> {
+        Ok(serde_json::from_value::<T>(
+            self.attrs
+                .as_ref()
+                .ok_or(GraphFormatError::MissingField("attrs"))?
+                .get(attr)
+                .ok_or_else(|| {
+                    GraphFormatError::MissingAttr("graph".to_string(), attr.to_string())
+                })?
+                .to_owned(),
+        )
+        .map_err(|err| GraphFormatError::Parse(err.into()))?)
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Node {
+    pub op: String,
+    pub name: String,
+    pub inputs: Vec<Entry>,
+    pub attrs: Option<HashMap<String, String>>,
+    pub control_deps: Option<Vec<Entry>>,
+}
+
+struct NodeAttrs {
+    func_name: String,
+    num_outputs: usize,
+    flatten_data: bool,
+}
+
+macro_rules! get_node_attr {
+    ($node:expr, $attrs:ident, $attr:literal) => {
+        $attrs
+            .get($attr)
+            .ok_or_else(|| GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned()))
+    };
+}
+
+impl Node {
+    fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
+        let attrs = self
+            .attrs
+            .as_ref()
+            .ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?;
+        Ok(NodeAttrs {
+            func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(),
+            num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?,
+            flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()? == 1,
+        })
+    }
+}
+
+impl<'a> TryFrom<&'a String> for Graph {
+    type Error = Error;
+    fn try_from(graph_json: &String) -> Result<Self, self::Error> {
+        let graph = serde_json::from_str(graph_json)?;
+        Ok(graph)
+    }
+}
+
+impl<'a> TryFrom<&'a str> for Graph {
+    type Error = Error;
+    fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
+        let graph = serde_json::from_str(graph_json)?;
+        Ok(graph)
+    }
+}
+
+/// A executor for a TVM computation graph.
+///
+/// # Examples
+///
+/// ```norun
+/// use ndarray::Array;
+///
+/// let syslib = SystemLibModule::default(); // a provider of TVM functions
+///
+/// let mut params_bytes = Vec::new();
+/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
+/// let params = tvm::runtime::load_param_dict(&params_bytes).unwrap();
+///
+/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
+///
+/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
+/// exec.load_params(params);
+///
+/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// exec.set_input("data", x.into());
+/// exec.run();
+/// let output = exec.get_output(0).unwrap();
+///
+/// println!("{:#?}", Array::try_from(output).unwrap());
+/// ```
+pub struct GraphExecutor<'m, 't> {
+    graph: Graph,
+    op_execs: Vec<Box<dyn Fn() + 'm>>,
+    tensors: Vec<Tensor<'t>>,
+}
+
+unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}

Review comment:
       What about `GraphExecutor` is not `Send`?

##########
File path: rust/tvm-graph-rt/src/array.rs
##########
@@ -0,0 +1,400 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
+
+use failure::{ensure, Error};
+use ndarray;
+use tvm_sys::{ffi::DLTensor, Context, DataType};
+
+use crate::allocator::Allocation;
+
+/// A `Storage` is a container which holds `Tensor` data.
+#[derive(PartialEq)]
+pub enum Storage<'a> {
+    /// A `Storage` which owns its contained bytes.
+    Owned(Allocation),
+
+    /// A view of an existing `Storage`.
+    View(&'a mut [u8], usize), // ptr, align
+}
+
+impl<'a> Storage<'a> {
+    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
+        Ok(Storage::Owned(Allocation::new(size, align)?))
+    }
+
+    pub fn as_mut_ptr(&self) -> *mut u8 {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_ptr(),
+            Storage::View(slice, _) => slice.as_ptr() as *mut u8,
+        }
+    }
+
+    pub fn size(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.size(),
+            Storage::View(slice, _) => slice.len(),
+        }
+    }
+
+    pub fn align(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.align(),
+            Storage::View(_, align) => *align,
+        }
+    }
+
+    pub fn as_ptr(&self) -> *const u8 {
+        self.as_mut_ptr() as *const _
+    }
+
+    /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
+    pub fn view(&self) -> Storage<'a> {
+        match self {
+            Storage::Owned(alloc) => Storage::View(
+                unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
+                self.align(),
+            ),
+            Storage::View(slice, _) => Storage::View(
+                unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
+                self.align(),
+            ),
+        }
+    }
+
+    pub fn is_owned(&self) -> bool {
+        match self {
+            Storage::Owned(_) => true,
+            _ => false,
+        }
+    }
+
+    /// Returns an owned version of this storage via cloning.
+    pub fn to_owned(&self) -> Storage<'static> {
+        let s = Storage::new(self.size(), Some(self.align())).unwrap();
+        unsafe {
+            s.as_mut_ptr()
+                .copy_from_nonoverlapping(self.as_ptr(), self.size());
+        }
+        s
+    }
+
+    /// Returns a view of the stored data.
+    pub fn as_slice(&self) -> &[u8] {
+        match self {
+            Storage::Owned(alloc) => alloc.as_slice(),
+            Storage::View(slice, _) => &*slice,
+        }
+    }
+
+    /// Returns a mutable view of the stored data.
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_slice(),
+            Storage::View(slice, _) => slice,
+        }
+    }
+}
+
+impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
+    fn from(data: &'d [T]) -> Self {
+        let data = unsafe {
+            slice::from_raw_parts_mut(
+                data.as_ptr() as *const u8 as *mut u8,
+                data.len() * mem::size_of::<T>() as usize,
+            )
+        };
+        Storage::View(data, mem::align_of::<T>())
+    }
+}
+
+/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
+/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
+/// converted to `ndarray::Array` for non-TVM processing.
+///
+/// # Examples
+///
+/// ```
+/// extern crate ndarray;
+/// use std::convert::TryInto;
+/// use tvm_runtime::{call_packed, DLTensor, ArgValue, RetValue, Tensor};
+///
+/// let mut a_nd: ndarray::Array1<f32> = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// let mut a: Tensor = a_nd.into();
+/// let mut a_dl: DLTensor = (&mut a).into();
+///
+/// let tvm_fn = |args: &[ArgValue]| -> Result<RetValue, ()> { Ok(RetValue::default()) };
+/// call_packed!(tvm_fn, &mut a_dl);
+///
+/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
+/// let mut a_nd: ndarray::ArrayD<f32> = a.try_into().unwrap();
+/// ```
+#[derive(PartialEq)]
+pub struct Tensor<'a> {

Review comment:
       Does it make sense to have separate owned `Tensor` and `TensorRef` types? I guess that can be added later.

##########
File path: rust/tvm-graph-rt/.travis.yml
##########
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+language: rust
+rust:
+  - nightly

Review comment:
       Nightly should be removed, right?
   ```suggestion
   ```

##########
File path: rust/tvm-graph-rt/src/threading.rs
##########
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    os::raw::{c_int, c_void},
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc, Barrier,
+    },
+    thread::{self, JoinHandle},
+};
+
+#[cfg(not(target_arch = "wasm32"))]
+use std::env;
+
+use crossbeam::channel::{bounded, Receiver, Sender};
+use tvm_sys::ffi::TVMParallelGroupEnv;
+
+pub(crate) type FTVMParallelLambda =
+    extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
+
+/// Holds a parallel job request made by a TVM library function.
+struct Job {
+    cb: FTVMParallelLambda,
+    cdata: *const c_void,
+    req_num_tasks: usize,
+    pending: Arc<AtomicUsize>,
+}
+
+impl Job {

Review comment:
       Any reason not to use `rayon` for this? I think you can tell it to spawn threads using TVM's thread pool: https://docs.rs/rayon/1.3.1/rayon/struct.ThreadPoolBuilder.html

##########
File path: rust/tvm-graph-rt/src/workspace.rs
##########
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    cell::RefCell,
+    os::raw::{c_int, c_void},
+    ptr,
+};
+
+use failure::{format_err, Error};
+
+use crate::allocator::Allocation;
+
+const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
+
+pub fn remove_item<T: PartialEq>(vec: &mut Vec<T>, item: &T) -> Option<T> {
+    let pos = vec.iter().position(|x| *x == *item)?;
+    Some(vec.remove(pos))
+}
+
+struct WorkspacePool {
+    workspaces: Vec<Allocation>,
+    free: Vec<usize>,
+    in_use: Vec<usize>,
+}
+
+impl WorkspacePool {
+    fn new() -> Self {
+        WorkspacePool {
+            workspaces: Vec::new(),
+            free: Vec::new(),
+            in_use: Vec::new(),
+        }
+    }
+
+    fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
+        self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
+        self.in_use.push(self.workspaces.len() - 1);
+        Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
+    }
+
+    fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
+        if self.free.is_empty() {
+            return self.alloc_new(size);
+        }
+        let idx = self
+            .free
+            .iter()
+            .fold(None, |cur_ws_idx: Option<usize>, &idx| {
+                let ws_size = self.workspaces[idx].size();
+                if ws_size < size {
+                    return cur_ws_idx;
+                }

Review comment:
       It seems like you could end up with some really extreme over-allocation if you have a combination of large and small tensors.

##########
File path: rust/tvm-graph-rt/src/array.rs
##########
@@ -0,0 +1,400 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
+
+use failure::{ensure, Error};
+use ndarray;
+use tvm_sys::{ffi::DLTensor, Context, DataType};
+
+use crate::allocator::Allocation;
+
+/// A `Storage` is a container which holds `Tensor` data.
+#[derive(PartialEq)]
+pub enum Storage<'a> {
+    /// A `Storage` which owns its contained bytes.
+    Owned(Allocation),
+
+    /// A view of an existing `Storage`.
+    View(&'a mut [u8], usize), // ptr, align
+}
+
+impl<'a> Storage<'a> {
+    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
+        Ok(Storage::Owned(Allocation::new(size, align)?))
+    }
+
+    pub fn as_mut_ptr(&self) -> *mut u8 {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_ptr(),
+            Storage::View(slice, _) => slice.as_ptr() as *mut u8,
+        }
+    }
+
+    pub fn size(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.size(),
+            Storage::View(slice, _) => slice.len(),
+        }
+    }
+
+    pub fn align(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.align(),
+            Storage::View(_, align) => *align,
+        }
+    }
+
+    pub fn as_ptr(&self) -> *const u8 {
+        self.as_mut_ptr() as *const _
+    }
+
+    /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
+    pub fn view(&self) -> Storage<'a> {
+        match self {
+            Storage::Owned(alloc) => Storage::View(
+                unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
+                self.align(),
+            ),
+            Storage::View(slice, _) => Storage::View(
+                unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
+                self.align(),
+            ),
+        }
+    }
+
+    pub fn is_owned(&self) -> bool {
+        match self {
+            Storage::Owned(_) => true,
+            _ => false,
+        }
+    }
+
+    /// Returns an owned version of this storage via cloning.
+    pub fn to_owned(&self) -> Storage<'static> {
+        let s = Storage::new(self.size(), Some(self.align())).unwrap();
+        unsafe {
+            s.as_mut_ptr()
+                .copy_from_nonoverlapping(self.as_ptr(), self.size());
+        }
+        s
+    }
+
+    /// Returns a view of the stored data.
+    pub fn as_slice(&self) -> &[u8] {
+        match self {
+            Storage::Owned(alloc) => alloc.as_slice(),
+            Storage::View(slice, _) => &*slice,
+        }
+    }
+
+    /// Returns a mutable view of the stored data.
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_slice(),
+            Storage::View(slice, _) => slice,
+        }
+    }
+}
+
+impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
+    fn from(data: &'d [T]) -> Self {
+        let data = unsafe {
+            slice::from_raw_parts_mut(
+                data.as_ptr() as *const u8 as *mut u8,
+                data.len() * mem::size_of::<T>() as usize,
+            )
+        };
+        Storage::View(data, mem::align_of::<T>())
+    }
+}
+
+/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
+/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
+/// converted to `ndarray::Array` for non-TVM processing.
+///
+/// # Examples
+///
+/// ```
+/// extern crate ndarray;
+/// use std::convert::TryInto;
+/// use tvm_runtime::{call_packed, DLTensor, ArgValue, RetValue, Tensor};
+///
+/// let mut a_nd: ndarray::Array1<f32> = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// let mut a: Tensor = a_nd.into();
+/// let mut a_dl: DLTensor = (&mut a).into();
+///
+/// let tvm_fn = |args: &[ArgValue]| -> Result<RetValue, ()> { Ok(RetValue::default()) };
+/// call_packed!(tvm_fn, &mut a_dl);
+///
+/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
+/// let mut a_nd: ndarray::ArrayD<f32> = a.try_into().unwrap();
+/// ```
+#[derive(PartialEq)]
+pub struct Tensor<'a> {
+    /// The bytes which contain the data this `Tensor` represents.
+    pub(crate) data: Storage<'a>,
+    pub(crate) ctx: Context,
+    pub(crate) dtype: DataType,
+    pub(crate) shape: Vec<i64>,
+    // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
+    /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
+    pub(crate) strides: Option<Vec<usize>>,
+    pub(crate) byte_offset: isize,
+    /// The number of elements in the `Tensor`.
+    pub(crate) size: usize,

Review comment:
       I would make these `pub(self)` or remove `pub` entirely, since it looks like you have unsafe code using them.

##########
File path: rust/tvm-graph-rt/src/array.rs
##########
@@ -0,0 +1,400 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
+
+use failure::{ensure, Error};
+use ndarray;
+use tvm_sys::{ffi::DLTensor, Context, DataType};
+
+use crate::allocator::Allocation;
+
+/// A `Storage` is a container which holds `Tensor` data.
+///
+/// NOTE(review): the derived `PartialEq` compares `Owned` values via `Allocation`'s derived
+/// `PartialEq` (its layout and pointer, i.e. allocation identity), but compares `View` values
+/// by their byte contents.
+#[derive(PartialEq)]
+pub enum Storage<'a> {
+    /// A `Storage` which owns its contained bytes.
+    Owned(Allocation),
+
+    /// A borrowed view of bytes owned elsewhere: either another `Storage` (as produced by
+    /// `Storage::view`) or a caller's slice (as produced by the `From<&[T]>` impl).
+    /// Fields: the viewed bytes, and their alignment in bytes.
+    View(&'a mut [u8], usize), // bytes, align
+}
+
+impl<'a> Storage<'a> {
+    /// Allocates a new owned `Storage` of `size` bytes.
+    ///
+    /// `align` is the requested alignment in bytes; `None` uses the allocator's default
+    /// (`DEFAULT_ALIGN_BYTES` in `allocator.rs`).
+    ///
+    /// # Errors
+    ///
+    /// Fails if `size` and `align` do not form a valid `Layout`.
+    ///
+    /// NOTE(review): `Allocation::new` hands the layout straight to `std::alloc::alloc`,
+    /// which is undefined behavior for a zero-sized layout; confirm that callers never
+    /// request `size == 0`.
+    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
+        Ok(Storage::Owned(Allocation::new(size, align)?))
+    }
+
+    /// Returns a raw pointer to the first byte of the storage.
+    ///
+    /// Writing through this pointer is only valid while no other reference to these
+    /// bytes is live.
+    pub fn as_mut_ptr(&self) -> *mut u8 {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_ptr(),
+            Storage::View(slice, _) => slice.as_ptr() as *mut u8,
+        }
+    }
+
+    /// Returns the number of bytes held by this storage.
+    pub fn size(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.size(),
+            Storage::View(slice, _) => slice.len(),
+        }
+    }
+
+    /// Returns the alignment, in bytes, recorded for this storage.
+    pub fn align(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.align(),
+            Storage::View(_, align) => *align,
+        }
+    }
+
+    /// Returns a read-only raw pointer to the first byte of the storage.
+    pub fn as_ptr(&self) -> *const u8 {
+        self.as_mut_ptr() as *const _
+    }
+
+    /// Returns a `Storage::View` over the same bytes and alignment as `self`, whether
+    /// `self` is `Owned` or already a `View`.
+    ///
+    /// NOTE(review): the returned view's lifetime `'a` is not tied to the borrow of `self`.
+    /// A view of an `Owned` storage dangles once that owner is dropped, and every view
+    /// mutably aliases the same bytes. Callers must keep the owner alive and avoid
+    /// overlapping writes themselves.
+    pub fn view(&self) -> Storage<'a> {
+        // SAFETY: `as_mut_ptr()` and `size()` describe exactly the bytes held by `self`;
+        // see the NOTE above for the lifetime and aliasing obligations left to the caller.
+        let bytes = unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) };
+        Storage::View(bytes, self.align())
+    }
+
+    /// Returns `true` if this storage owns its bytes (`Storage::Owned`).
+    pub fn is_owned(&self) -> bool {
+        match self {
+            Storage::Owned(_) => true,
+            _ => false,
+        }
+    }
+
+    /// Returns an owned deep copy of this storage's bytes, with the same alignment.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the size and alignment do not form a valid layout. This cannot happen for
+    /// storages built by `new`, `view` or `From<&[T]>`.
+    pub fn to_owned(&self) -> Storage<'static> {
+        let copy = Storage::new(self.size(), Some(self.align()))
+            .expect("an existing storage's size and alignment must form a valid layout");
+        // SAFETY: both pointers are valid for `self.size()` bytes, and the freshly
+        // allocated copy cannot overlap `self`.
+        unsafe {
+            copy.as_mut_ptr()
+                .copy_from_nonoverlapping(self.as_ptr(), self.size());
+        }
+        copy
+    }
+
+    /// Returns a view of the stored data.
+    pub fn as_slice(&self) -> &[u8] {
+        match self {
+            Storage::Owned(alloc) => alloc.as_slice(),
+            Storage::View(slice, _) => &*slice,
+        }
+    }
+
+    /// Returns a mutable view of the stored data.
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_slice(),
+            Storage::View(slice, _) => slice,
+        }
+    }
+}
+
+/// Wraps a borrowed slice of `T` as a byte-level `Storage::View`.
+///
+/// The view spans `data.len() * size_of::<T>()` bytes and records `align_of::<T>()` as its
+/// alignment. The storage's lifetime is tied to the borrow of `data`, so it cannot outlive
+/// the slice it points into.
+impl<'a, T> From<&'a [T]> for Storage<'a> {
+    fn from(data: &'a [T]) -> Self {
+        let num_bytes = data.len() * mem::size_of::<T>();
+        // SAFETY: the pointer and length cover exactly the memory of `data`, which stays
+        // borrowed for `'a`. NOTE(review): this still derives a `&mut [u8]` from a shared
+        // `&[T]`, which Rust's aliasing rules do not permit; at the very least, never write
+        // through a storage built this way (e.g. via `Storage::as_mut_slice`).
+        let bytes = unsafe {
+            slice::from_raw_parts_mut(data.as_ptr() as *const u8 as *mut u8, num_bytes)
+        };
+        Storage::View(bytes, mem::align_of::<T>())
+    }
+}
+
+/// An n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
+/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
+/// converted to `ndarray::Array` for non-TVM processing.
+///
+/// # Examples
+///
+/// ```
+/// extern crate ndarray;
+/// use std::convert::TryInto;
+/// use tvm_runtime::{call_packed, DLTensor, ArgValue, RetValue, Tensor};
+///
+/// let mut a_nd: ndarray::Array1<f32> = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// let mut a: Tensor = a_nd.into();
+/// let mut a_dl: DLTensor = (&mut a).into();
+///
+/// let tvm_fn = |args: &[ArgValue]| -> Result<RetValue, ()> { Ok(RetValue::default()) };
+/// call_packed!(tvm_fn, &mut a_dl);
+///
+/// // Tensor -> Array is mostly useful when post-processing TVM graph outputs.
+/// let mut a_nd: ndarray::ArrayD<f32> = a.try_into().unwrap();
+/// ```
+#[derive(PartialEq)]
+pub struct Tensor<'a> {
+    /// The bytes which contain the data this `Tensor` represents.
+    pub(crate) data: Storage<'a>,
+    /// The `Context` (device) the data is associated with.
+    pub(crate) ctx: Context,
+    /// The element data type.
+    pub(crate) dtype: DataType,
+    /// The extent of each dimension; their product is `size`.
+    pub(crate) shape: Vec<i64>,
+    // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
+    /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
+    pub(crate) strides: Option<Vec<usize>>,
+    /// Offset of the first element within `data`, in bytes.
+    /// NOTE(review): presumably mirrors `DLTensor::byte_offset`; confirm.
+    pub(crate) byte_offset: isize,
+    /// The number of elements in the `Tensor`.
+    pub(crate) size: usize,
+}
+
+// NOTE(review): `Send` is asserted manually, presumably because `Storage::Owned` wraps an
+// `Allocation` holding a raw `*mut u8`, which is not `Send` on its own. Tensors whose
+// `Storage::View`s alias the same bytes (as `Storage::view` permits) can race if they are
+// moved to different threads and written concurrently.
+unsafe impl<'a> Send for Tensor<'a> {}
+
+impl<'a> Tensor<'a> {
+    pub fn shape(&self) -> Vec<i64> {
+        self.shape.clone()
+    }
+
+    pub fn data(&self) -> &Storage {
+        &self.data
+    }
+
+    pub fn data_mut(&mut self) -> &'a mut Storage {
+        &mut self.data
+    }

Review comment:
       I think the `Tensor<'static>` returned below also breaks with this.

##########
File path: rust/tvm-graph-rt/src/workspace.rs
##########
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    cell::RefCell,
+    os::raw::{c_int, c_void},
+    ptr,
+};
+
+use failure::{format_err, Error};
+
+use crate::allocator::Allocation;
+
+/// Alignment, in bytes, presumably applied to workspace allocations (usage not shown here).
+const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
+
+/// Removes the first element of `vec` equal to `item` and returns it.
+///
+/// Returns `None`, leaving `vec` untouched, when no element matches.
+pub fn remove_item<T: PartialEq>(vec: &mut Vec<T>, item: &T) -> Option<T> {
+    let found_at = vec.iter().position(|candidate| candidate == item);
+    found_at.map(|idx| vec.remove(idx))
+}
+
+struct WorkspacePool {
+    workspaces: Vec<Allocation>,
+    free: Vec<usize>,
+    in_use: Vec<usize>,

Review comment:
       Could be a good use for https://docs.rs/hibitset/0.6.3/hibitset/

##########
File path: rust/tvm-sys/src/datatype.rs
##########
@@ -73,6 +73,18 @@ impl DataType {
     pub fn lanes(&self) -> usize {
         self.lanes as usize
     }
+
+    pub const fn int(bits: u8, lanes: u16) -> DataType {
+        DataType::new(DL_INT_CODE, bits, lanes)
+    }
+
+    pub const fn float(bits: u8, lanes: u16) -> DataType {
+        DataType::new(DL_FLOAT_CODE, bits, lanes)
+    }
+
+    pub const fn uint(bits: u8, lanes: u16) -> DataType {
+        DataType::new(DL_FLOAT_CODE, bits, lanes)

Review comment:
       Should this be `DL_UINT_CODE` rather than `DL_FLOAT_CODE`?
   ```suggestion
           DataType::new(DL_UINT_CODE, bits, lanes)
   ```

##########
File path: rust/tvm-graph-rt/src/graph.rs
##########
@@ -0,0 +1,495 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
+
+use failure::{ensure, format_err, Error};
+use itertools::izip;
+use nom::{
+    character::complete::{alpha1, digit1},
+    complete, count, do_parse, length_count, map, named,
+    number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8},
+    opt, tag, take, tuple,
+};
+
+use serde::{Deserialize, Serialize};
+use serde_json;
+
+use tvm_sys::ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt};
+
+use tvm_sys::{ffi::DLTensor, ArgValue, Context, DataType, DeviceType};
+
+use crate::{errors::GraphFormatError, Module, Storage, Tensor};
+
+// @see `kTVMNDArrayMagic` in `ndarray.h`
+// NOTE(review): not referenced in the code shown here; the leading `_` silences rustc's
+// unused-item lint. Presumably intended for validating serialized NDArray headers.
+const _NDARRAY_MAGIC: u64 = 0xDD5E_40F0_96B4_A13F;
+// @see `kTVMNDArrayListMagic` in `graph_runtime.h`
+// NOTE(review): likewise unused here; presumably the magic of a serialized NDArray list.
+const _NDARRAY_LIST_MAGIC: u64 = 0xF7E5_8D4F_0504_9CB7;
+
+/// A TVM computation graph, as deserialized from the graph JSON.
+///
+/// # Examples
+///
+/// ```ignore
+/// let graph_json = fs::read_to_string("graph.json").unwrap();
+/// let graph = Graph::try_from(&graph_json).unwrap();
+/// ```
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Graph {
+    /// All nodes of the graph.
+    pub nodes: Vec<Node>,
+    /// NOTE(review): presumably the indices into `nodes` of the graph's input (argument)
+    /// nodes; inferred from the name.
+    pub arg_nodes: Vec<usize>,
+    /// NOTE(review): presumably the graph's output entries; inferred from the name.
+    pub heads: Vec<Entry>,
+    /// Per-node row offsets: `entry_index` computes `node_row_ptr[entry.id] + entry.index`.
+    pub node_row_ptr: Option<Vec<usize>>,
+    /// Graph-level attributes (e.g. `storage_id`, `shape`, `dltype`), read with `get_attr`
+    /// as `(String, Vec<_>)` pairs whose second element holds the per-entry values.
+    pub attrs: Option<HashMap<String, serde_json::Value>>,
+}
+
+/// A reference to one output of a graph node.
+///
+/// Keep the field order: serde's derived `Deserialize` also accepts a JSON array, which is
+/// read positionally as `[id, index, version]`.
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Entry {
+    /// Id of the node producing this entry (used to index `Graph::node_row_ptr`).
+    pub id: usize,
+    /// Which output of that node this entry is (added to `node_row_ptr[id]`).
+    pub index: usize,
+    /// NOTE(review): carried through from the graph JSON; not used by the code shown here.
+    pub version: usize,
+}
+
+impl Graph {
+    /// Returns the flat index of `entry`: the row offset of its node
+    /// (`node_row_ptr[entry.id]`) plus the entry's output `index`.
+    ///
+    /// # Errors
+    ///
+    /// `GraphFormatError::MissingField("node_row_ptr")` when the graph has no `node_row_ptr`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `entry.id` is out of bounds for `node_row_ptr`.
+    fn entry_index(&self, entry: &Entry) -> Result<usize, GraphFormatError> {
+        match self.node_row_ptr {
+            Some(ref row_ptr) => Ok(row_ptr[entry.id] + entry.index),
+            None => Err(GraphFormatError::MissingField("node_row_ptr")),
+        }
+    }
+
+    /// Deserializes the graph-level JSON attribute `attr` into a `T`.
+    ///
+    /// # Errors
+    ///
+    /// * `MissingField("attrs")` if the graph has no `attrs` map;
+    /// * `MissingAttr("graph", attr)` if `attr` is not present in it;
+    /// * `Parse` if the JSON value does not deserialize as `T`.
+    fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T, GraphFormatError> {
+        let attrs = match self.attrs {
+            Some(ref attrs) => attrs,
+            None => return Err(GraphFormatError::MissingField("attrs")),
+        };
+        let value = match attrs.get(attr) {
+            Some(value) => value.to_owned(),
+            None => {
+                return Err(GraphFormatError::MissingAttr(
+                    "graph".to_string(),
+                    attr.to_string(),
+                ));
+            }
+        };
+        serde_json::from_value::<T>(value).map_err(|err| GraphFormatError::Parse(err.into()))
+    }
+}
+
+/// A single node of the graph JSON.
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Node {
+    /// Operator kind of the node. NOTE(review): confirm the set of values against the graph
+    /// format (e.g. operator vs. input placeholder nodes).
+    pub op: String,
+    /// Node name; also reported in `GraphFormatError::MissingAttr` by `parse_attrs`.
+    pub name: String,
+    /// Entries consumed by this node.
+    pub inputs: Vec<Entry>,
+    /// String-valued operator attributes, interpreted by `Node::parse_attrs`.
+    pub attrs: Option<HashMap<String, String>>,
+    /// Control dependencies, when present in the JSON.
+    pub control_deps: Option<Vec<Entry>>,
+}
+
+/// Operator attributes extracted from a `Node` by `Node::parse_attrs`.
+struct NodeAttrs {
+    /// Value of the node's `func_name` attribute.
+    func_name: String,
+    /// Value of the node's `num_outputs` attribute, parsed as a `usize`.
+    num_outputs: usize,
+    /// `true` exactly when the node's `flatten_data` attribute parses to `1`.
+    flatten_data: bool,
+}
+
+/// Looks up key `$attr` in the map `$attrs` belonging to node `$node`.
+///
+/// Expands to a `Result` that is `Ok` with a reference to the value, or
+/// `Err(GraphFormatError::MissingAttr(node, attr))` when the key is absent.
+macro_rules! get_node_attr {
+    ($node:expr, $attrs:ident, $attr:literal) => {
+        match $attrs.get($attr) {
+            Some(value) => Ok(value),
+            None => Err(GraphFormatError::MissingAttr(
+                $node.to_owned(),
+                $attr.to_owned(),
+            )),
+        }
+    };
+}
+
+impl Node {
+    /// Extracts the `func_name`, `num_outputs` and `flatten_data` attributes of this node.
+    ///
+    /// # Errors
+    ///
+    /// Fails with `GraphFormatError::MissingAttr` if the node has no `attrs` map or lacks
+    /// one of the three keys. Fails with a parse error if `num_outputs` is not a `usize` or
+    /// `flatten_data` is not a `u8`.
+    fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
+        let attrs = self.attrs.as_ref().ok_or_else(|| {
+            GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned())
+        })?;
+
+        let func_name = get_node_attr!(self.name, attrs, "func_name")?.to_owned();
+        let num_outputs = get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?;
+        // Only the exact value `1` enables flattening; any other integer disables it.
+        let flatten_flag = get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()?;
+
+        Ok(NodeAttrs {
+            func_name,
+            num_outputs,
+            flatten_data: flatten_flag == 1,
+        })
+    }
+}
+
+/// Parses a `Graph` from its JSON text; delegates to the `&str` implementation.
+impl<'a> TryFrom<&'a String> for Graph {
+    type Error = Error;
+    fn try_from(graph_json: &String) -> Result<Self, self::Error> {
+        <Graph as TryFrom<&str>>::try_from(graph_json.as_str())
+    }
+}
+
+/// Parses a `Graph` from its JSON text.
+///
+/// # Errors
+///
+/// Returns the `serde_json` error, converted into `failure::Error`, if the text does not
+/// deserialize as a `Graph`.
+impl<'a> TryFrom<&'a str> for Graph {
+    type Error = Error;
+    fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
+        serde_json::from_str::<Graph>(graph_json).map_err(Error::from)
+    }
+}
+
+/// An executor for a TVM computation graph.
+///
+/// # Examples
+///
+/// ```ignore
+/// use ndarray::Array;
+///
+/// let syslib = SystemLibModule::default(); // a provider of TVM functions
+///
+/// let mut params_bytes = Vec::new();
+/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
+/// let params = tvm::runtime::load_param_dict(&params_bytes).unwrap();
+///
+/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
+///
+/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
+/// exec.load_params(params);
+///
+/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// exec.set_input("data", x.into());
+/// exec.run();
+/// let output = exec.get_output(0).unwrap();
+///
+/// println!("{:#?}", Array::try_from(output).unwrap());
+/// ```
+pub struct GraphExecutor<'m, 't> {
+    /// The graph being executed.
+    graph: Graph,
+    /// One closure per operator, built by `setup_op_execs` from the module borrowed for
+    /// `'m`; `run` invokes them in order.
+    op_execs: Vec<Box<dyn Fn() + 'm>>,
+    /// Tensors for the graph's entries, built by `setup_storages`; several may be views
+    /// sharing one underlying `Storage`.
+    tensors: Vec<Tensor<'t>>,
+}
+
+// NOTE(review): `Send` is asserted manually. The boxed `op_execs` closures are not required
+// to be `Send`, and the tensors may alias one another's storage (see `setup_storages`), so
+// soundness rests on how the executor is used; confirm before relying on it.
+unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
+
+impl<'m, 't> GraphExecutor<'m, 't> {
+    pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
+        let tensors = Self::setup_storages(&graph)?;
+        Ok(GraphExecutor {
+            op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
+            tensors,
+            graph,
+        })
+    }
+
+    /// Runs the computation graph.
+    pub fn run(&mut self) {
+        self.op_execs.iter().for_each(|op_exec| {
+            op_exec();
+        });
+    }
+
+    /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
+    fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
+        let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
+        let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
+        let dtypes = graph
+            .get_attr::<(String, Vec<String>)>("dltype")?
+            .1
+            .iter()
+            .map(|dltype| {
+                if let Ok((_, dtype)) = tvm_str_to_type(dltype) {
+                    Ok(dtype)
+                } else {
+                    Err(GraphFormatError::InvalidDLType(dltype.to_string()))
+                }
+            })
+            .collect::<Result<Vec<DataType>, GraphFormatError>>()?;
+
+        let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max();
+        let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
+        for (i, &storage_id) in storage_ids.iter().enumerate() {
+            let dtype_size = (dtypes[i].bits() * dtypes[i].lanes()) >> 3;
+            let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
+            storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
+        }
+
+        let mut storages: Vec<Storage> = storage_num_bytes
+            .into_iter()
+            .map(|nbytes| Storage::new(nbytes, align))
+            .collect::<Result<Vec<Storage>, Error>>()?;
+
+        let tensors = izip!(storage_ids, shapes, dtypes)
+            .map(|(storage_id, shape, dtype)| {
+                let storage = storages[storage_id].view();
+                Tensor {
+                    data: mem::replace(&mut storages[storage_id], storage),
+                    ctx: Context::default(),
+                    dtype,
+                    size: shape.iter().product::<i64>() as usize,
+                    shape,
+                    strides: None,
+                    byte_offset: 0,
+                }
+            })
+            .collect();
+
+        Ok(tensors)
+    }
+
+    /// Creates closures which represent the computation performed by this graph.
+    fn setup_op_execs<M: 'm + Module>(
+        graph: &Graph,
+        lib: &'m M,
+        tensors: &[Tensor<'t>],
+    ) -> Result<Vec<Box<dyn Fn() + 'm>>, Error> {

Review comment:
       Maybe make these execs functions of `Tensor<'t>` or `for<'a> Tensor<'a>`?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message