tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [tvm] areusch commented on a change in pull request #8650: Refactor AOT Test Utils parameters into object
Date Mon, 09 Aug 2021 19:21:25 GMT

areusch commented on a change in pull request #8650:
URL: https://github.com/apache/tvm/pull/8650#discussion_r685455522



##########
File path: tests/python/relay/aot/aot_test_utils.py
##########
@@ -39,6 +40,30 @@
 _LOG = logging.getLogger(__name__)
 
 
+class AOTTestNetwork(NamedTuple):
+    """Class to describe a network under test
+
+    Parameters
+    ----------
+    module: tvm.IRModule
+        IRModule to generate AOT executor for
+    inputs: Dict[str, np.array]
+        Dict of input names to value arrays
+    outputs: List[np.array]
+        Ordered list of output value arrays
+    name: str
+        Name to use for this network
+    params: Optional[Dict[str, np.array]]
+        Dict of parameter names to value arrays
+    """
+
+    module: tvm.IRModule

Review comment:
       IIRC we are still on Python 3.6 in the CI images. Can you check that?

##########
File path: tests/python/relay/aot/aot_test_utils.py
##########
@@ -430,132 +436,60 @@ def compile_and_run(
         os.path.join(include_path, "crt_config.h"),
     )
 
-    for key in inputs:
-        create_header_file(
-            f'{mangle_name(mod_name, "input_data")}_{key}',
-            inputs[key],
-            os.path.join(base_path, "include"),
-        )
-
-    for i in range(len(output_list)):
-        create_header_file(
-            f'{mangle_name(mod_name,"output_data")}{i}',
-            np.zeros(output_list[i].shape, output_list[i].dtype),
-            os.path.join(base_path, "include"),
-        )
-        create_header_file(
-            f'{mangle_name(mod_name, "expected_output_data")}{i}',
-            output_list[i],
-            os.path.join(base_path, "include"),
-        )
-
-    create_main(
-        "test.c",
-        {mod_name: inputs},
-        {mod_name: output_list},
-        build_path,
-        interface_api,
-        workspace_bytes,
-    )
-
-    # Verify that compiles fine
-    file_dir = os.path.dirname(os.path.abspath(__file__))
-    codegen_path = os.path.join(base_path, "codegen")
-    makefile = os.path.join(file_dir, "aot_test.mk")
-    make_cmd = (
-        f"make CFLAGS='{cflags}' -f {makefile} build_dir="
-        + build_path
-        + f" TVM_ROOT={file_dir}/../../../.."
-        + f" CODEGEN_ROOT={codegen_path}"
-        + f" STANDALONE_CRT_DIR={tvm.micro.get_standalone_crt_dir()}"
-    )
-
-    compile_log_path = os.path.join(build_path, "test_compile.log")
-    ret = subprocess_log_output(make_cmd, ".", compile_log_path)
-    assert ret == 0
-
-    # Verify that runs fine
-    run_log_path = os.path.join(build_path, "test_run.log")
-    ret = subprocess_log_output("./aot_test_runner", build_path, run_log_path)
-    assert ret == 0
-
-
-def compile_and_run_multiple_models(
-    mod_map,
-    input_list_map,
-    output_list_map,
-    interface_api,
-    use_unpacked_api,
-    use_calculated_workspaces,
-    param_map,
-    workspace_byte_alignment=8,
-):
-    """
-    This method verifies the generated source
-    """
-    base_target = "c -runtime=c --link-params --executor=aot"
-    extra_target = f"--workspace-byte-alignment={workspace_byte_alignment} --interface-api={interface_api} --unpacked-api={int(use_unpacked_api)}"
-    target = f"{base_target} {extra_target}"
-    tmp_path = utils.tempdir()
-    tmp_dir = tmp_path.temp_dir
-
-    base_path = os.path.join(tmp_dir, "test")
-    build_path = os.path.join(base_path, "build")
-    os.makedirs(build_path, exist_ok=True)
-
-    include_path = os.path.join(base_path, "include")
-    os.mkdir(include_path)
-    crt_root = tvm.micro.get_standalone_crt_dir()
-    shutil.copy2(
-        os.path.join(crt_root, "template", "crt_config-template.h"),
-        os.path.join(include_path, "crt_config.h"),
-    )
-
-    for mod_name, mod in mod_map.items():
-
-        with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+    workspace_bytes = 0
+    for network in networks:
+        with tvm.transform.PassContext(opt_level=3, config=config):
             lib = tvm.relay.build(
-                mod, target, target_host=target, params=param_map[mod_name], mod_name=mod_name
+                network.module,
+                target,
+                target_host=target,
+                params=network.params,
+                mod_name=network.name,
             )
 
-        tar_file = os.path.join(base_path, "test.tar")
+        tar_file = os.path.join(base_path, f"{network.name}.tar")
         export_model_library_format(lib, tar_file)
         t = tarfile.open(tar_file)
         t.extractall(base_path)
 
-        input_list = input_list_map[mod_name]
-        output_list = output_list_map[mod_name]
+        if use_calculated_workspaces:
+            workspace_bytes += extract_main_workspace_sizebytes(base_path)

Review comment:
       nit: `size_bytes` (i.e. `extract_main_workspace_size_bytes`)

##########
File path: tests/python/relay/aot/aot_test_utils.py
##########
@@ -303,40 +328,34 @@ def emit_main_micro_include(main_file, mod_name):
     main_file.write(f"#include <{mangle_module_name(mod_name)}.h>\n")
 
 
-def create_main(test_name, input_map, output_list_map, output_path, interface_api, workspace_bytes):
+def create_main(test_name, networks, output_path, interface_api, workspace_bytes):

Review comment:
       nit: I think we have generally preferred `models` over `networks`.

##########
File path: tests/python/relay/aot/aot_test_utils.py
##########
@@ -395,6 +410,9 @@ def compile_and_run(
     target = f"{base_target} {extra_target}"
     cflags = f"-DTVM_RUNTIME_ALLOC_ALIGNMENT_BYTES={workspace_byte_alignment} "
 
+    if not isinstance(networks, list):
+        networks = [networks]

Review comment:
       Do you want to `assert isinstance(networks, AOTTestNetwork)` here?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message