From a4ae5dbbdab58820bd824737e54f9714c40fafb7 Mon Sep 17 00:00:00 2001 From: sgauthamr2001 <sgauthamr2001@gmail.com> Date: Sat, 4 Jan 2025 23:55:19 -0800 Subject: [PATCH] Add pretty printing of loop nest --- src/interstellar/utils.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/interstellar/utils.py b/src/interstellar/utils.py index 725917d..21081ab 100644 --- a/src/interstellar/utils.py +++ b/src/interstellar/utils.py @@ -2,7 +2,7 @@ from . import loop_enum as le from . import buffer_enum as be -def print_loop_nest(point): +def get_loop_nest(point): loop_orders = list(zip(*point.loop_orders)) loop_blockings = list(zip(*point.loop_blockings)) loop_partitionings = list(zip(*point.loop_partitionings)) @@ -24,6 +24,41 @@ def print_loop_nest(point): order_lists.append(order_list) + return order_lists, para_dims + + +def print_tiling(point): + order_lists, _ = get_loop_nest(point) + + bottom_up_prints = [] + + for level in order_lists: + for loops in level: + if loops is not None: + if loops[2] == 1: + bottom_up_prints.append( + f"for {loops[0]} in range({int(loops[1])}):" + ) + else: + bottom_up_prints.append( + f"parallel_for {loops[0]} in range({int(loops[2])}):" + ) + else: + bottom_up_prints.append("") + break + + space_count = 0 + for i in range(len(bottom_up_prints) - 1, -1, -1): + if bottom_up_prints[i] == "": + print(bottom_up_prints[i]) + else: + print((" " * space_count) + bottom_up_prints[i]) + space_count += 2 + + +def print_loop_nest(point): + order_lists, para_dims = get_loop_nest(point) + print(order_lists, para_dims) -- GitLab