From b34fb0ca7eb2652f18817cedd861ae08f974a4a9 Mon Sep 17 00:00:00 2001
From: sgauthamr2001 <sgauthamr2001@gmail.com>
Date: Wed, 8 Jan 2025 21:49:22 -0800
Subject: [PATCH] Revert "Fix access count logic, validated with simulation"

This reverts commit 46994104de43afd973865c11dcb57cdaa14a6e39.
---
 src/interstellar/cost_model.py | 279 +++++----------------------------
 1 file changed, 36 insertions(+), 243 deletions(-)

diff --git a/src/interstellar/cost_model.py b/src/interstellar/cost_model.py
index e0ccec7..f1df0e6 100644
--- a/src/interstellar/cost_model.py
+++ b/src/interstellar/cost_model.py
@@ -94,89 +94,7 @@ def valid_dataflow(resource, hint):
     return True
 
 
-def get_if_access(resource, point, layer, mac_capacity=1):
-    """
-    Returns the number of accesses to the inputs for each level.
-    """
-
-    irrelevant_loops = [le.OC, le.FX, le.FY]
-
-    num_levels = resource.buffer_levels()
-    access_counts_per_level = []
-
-    for level in range(num_levels):
-        # general idea: total number of accesses = tiling at current level * block size * num_blocks
-        # block size = tiling without the irrelevant loops
-        # num_blocks = tiling at the levels above the current level
-
-        # multiply all the tiling factors at the current level
-
-        def multiply_tiling_factors():
-            # find the innermost loop among [OX, OY, IC, ON]
-            lowest_input_loop_index = min(
-                point.loop_orders[le.OX][level],
-                point.loop_orders[le.OY][level],
-                point.loop_orders[le.IC][level],
-                point.loop_orders[le.ON][level],
-                # these are partially relevant
-                point.loop_orders[le.FX][level],
-                point.loop_orders[le.FY][level],
-            )
-
-            # we can ignore OC if it is at a lower level than the innermost input loop
-            # FX, FY can't be ignored, because they are partially relevant
-            tiling = 1
-            for i in range(le.NUM):
-                if i in [le.OC]:
-                    if point.loop_orders[i][level] > lowest_input_loop_index:
-                        # if the loop is at a higher level than the innermost input loop, we need to consider it
-                        tiling *= point.loop_blockings[i][level]
-                else:
-                    tiling *= point.loop_blockings[i][level]
-            return tiling
-
-        # remove all the irrelevant loops from the tiling of the levels below
-        def calculate_block_size():
-            block_size = 1
-
-            for lower_level in range(level - 1, -1, -1):
-                for i in range(le.NUM):
-                    if i not in irrelevant_loops:
-                        if i == le.OX:
-                            block_size *= point.loop_blockings[i][lower_level] + (
-                                point.loop_blockings[le.FX][lower_level] - 1
-                            )
-                        elif i == le.OY:
-                            block_size *= point.loop_blockings[i][lower_level] + (
-                                point.loop_blockings[le.FY][lower_level] - 1
-                            )
-                        else:
-                            block_size *= point.loop_blockings[i][lower_level]
-                        block_size *= point.loop_partitionings[i][lower_level]
-
-            return block_size
-
-        def get_num_blocks():
-            # get tiling of the levels above the current level
-            num_blocks = 1
-            for i in range(level + 1, num_levels):
-                for j in range(le.NUM):
-                    num_blocks *= point.loop_blockings[j][i]
-            return num_blocks
-
-        access_counts_per_level.append(
-            multiply_tiling_factors()
-            * calculate_block_size()
-            * get_num_blocks()
-            * resource.paras[level].count
-        )
-
-    # print("Accesses at each level: ", access_counts_per_level)
-
-    return access_counts_per_level
-
-
-def get_if_access_old(level, point, layer, mac_capacity=1):
+def get_if_access(level, point, layer, mac_capacity=1):
     """
     Get per element # of access of Input at current level
 
@@ -214,70 +132,7 @@ def get_if_access_old(level, point, layer, mac_capacity=1):
     )
 
 
-def get_of_access(resource, point, layer, mac_capacity=1):
-    irrelevant_loops = [le.FX, le.FY, le.IC]
-
-    num_levels = resource.buffer_levels()
-    access_counts_per_level = []
-    for level in range(num_levels):
-        # general idea: total number of accesses = tiling at current level * block size * num_blocks
-        # block size = tiling without the irrelevant loops
-        # num_blocks = tiling at the levels above the current level
-
-        # multiply all the tiling factors at the current level
-
-        def multiply_tiling_factors():
-
-            lowest_relevant_loop_index = min(
-                point.loop_orders[le.OX][level],
-                point.loop_orders[le.OY][level],
-                point.loop_orders[le.OC][level],
-                point.loop_orders[le.ON][level],
-            )
-
-            # we can ignore OX,OY,ON since they are not relevant to the weight
-            tiling = 1
-            for i in range(le.NUM):
-                if i in irrelevant_loops:
-                    if point.loop_orders[i][level] > lowest_relevant_loop_index:
-                        tiling *= point.loop_blockings[i][level]
-                else:
-                    tiling *= point.loop_blockings[i][level]
-            return tiling
-
-        # remove all the irrelevant loops from the tiling of the levels below
-        def calculate_block_size():
-            block_size = 1
-
-            for lower_level in range(level - 1, -1, -1):
-                for i in range(le.NUM):
-                    if i not in irrelevant_loops:
-                        block_size *= point.loop_blockings[i][lower_level]
-                        block_size *= point.loop_partitionings[i][lower_level]
-
-            return block_size
-
-        def get_num_blocks():
-            # get tiling of the levels above the current level
-            num_blocks = 1
-            for i in range(level + 1, num_levels):
-                for j in range(le.NUM):
-                    num_blocks *= point.loop_blockings[j][i]
-            return num_blocks
-
-        access_counts_per_level.append(
-            multiply_tiling_factors()
-            * calculate_block_size()
-            * get_num_blocks()
-            * resource.paras[level].count
-        )
-
-    # print("Accesses at each level: ", access_counts_per_level)
-
-    return access_counts_per_level
-
-
-def get_of_access_old(level, point, layer, mac_capacity=1):
+def get_of_access(level, point, layer, mac_capacity=1):
     """
     Get per element # of access of Output at current level
 
@@ -325,74 +180,7 @@ def get_of_access_old(level, point, layer, mac_capacity=1):
     return fx_acc * fy_acc * ic_acc * fx_par * fy_par * ic_par
 
 
-def get_fl_access(resource, point, layer, mac_capacity=1):
-    """
-    Returns the number of accesses to the inputs for each level.
-    """
-
-    irrelevant_loops = [le.OX, le.OY, le.ON]
-
-    num_levels = resource.buffer_levels()
-    access_counts_per_level = []
-    for level in range(num_levels):
-        # general idea: total number of accesses = tiling at current level * block size * num_blocks
-        # block size = tiling without the irrelevant loops
-        # num_blocks = tiling at the levels above the current level
-
-        # multiply all the tiling factors at the current level
-
-        def multiply_tiling_factors():
-
-            lowest_relevant_loop_index = min(
-                point.loop_orders[le.FX][level],
-                point.loop_orders[le.FY][level],
-                point.loop_orders[le.IC][level],
-                point.loop_orders[le.OC][level],
-            )
-
-            # we can ignore OX,OY,ON since they are not relevant to the weight
-            tiling = 1
-            for i in range(le.NUM):
-                if i in irrelevant_loops:
-                    if point.loop_orders[i][level] > lowest_relevant_loop_index:
-                        tiling *= point.loop_blockings[i][level]
-                else:
-                    tiling *= point.loop_blockings[i][level]
-            return tiling
-
-        # remove all the irrelevant loops from the tiling of the levels below
-        def calculate_block_size():
-            block_size = 1
-
-            for lower_level in range(level - 1, -1, -1):
-                for i in range(le.NUM):
-                    if i not in irrelevant_loops:
-                        block_size *= point.loop_blockings[i][lower_level]
-                        block_size *= point.loop_partitionings[i][lower_level]
-
-            return block_size
-
-        def get_num_blocks():
-            # get tiling of the levels above the current level
-            num_blocks = 1
-            for i in range(level + 1, num_levels):
-                for j in range(le.NUM):
-                    num_blocks *= point.loop_blockings[j][i]
-            return num_blocks
-
-        access_counts_per_level.append(
-            multiply_tiling_factors()
-            * calculate_block_size()
-            * get_num_blocks()
-            * resource.paras[level].count
-        )
-
-    # print("Accesses at each level: ", access_counts_per_level)
-
-    return access_counts_per_level
-
-
-def get_fl_access_old(level, point, layer, mac_capacity=1):
+def get_fl_access(level, point, layer, mac_capacity=1):
     """
     Get per element # of access of Weight at current level
 
@@ -859,12 +647,11 @@ def get_access(point, layer, resource):
     mac_capacity = resource.mac_capacity
 
     access_list = []
-
-    if_accesses = get_if_access(resource, point, layer, mac_capacity)
-    of_accesses = get_of_access(resource, point, layer, mac_capacity)
-    fl_accesses = get_fl_access(resource, point, layer, mac_capacity)
-
-    access_list = list(zip(if_accesses, of_accesses, fl_accesses))
+    for level in range(num_levels):
+        if_block_access = get_if_access(level, point, layer, mac_capacity)
+        of_block_access = 2 * get_of_access(level, point, layer, mac_capacity) - 1
+        fl_block_access = get_fl_access(level, point, layer, mac_capacity)
+        access_list.append([if_block_access, of_block_access, fl_block_access])
 
     # para_mode = [e.access_mode for i, e in enumerate(resource.paras) if e.access_mode != 0]
     para_mode_level = [i for i, e in enumerate(resource.paras) if e.access_mode != 0]
@@ -1220,7 +1007,8 @@ def get_array_level_cost(
 
     total_cost = 0
     for i in range(len(level_access)):
-        total_cost += level_access[i] * level_cost[i]
+        buffer_access = list(map(mul, level_access[i], layer_size))
+        total_cost += sum(buffer_access) * level_cost[i]
 
     if verbose >= 3:
         print("Level ", level, " array level access: ", level_access)
@@ -1246,22 +1034,23 @@ def get_array_and_curr_level_cost(resource, point, layer, level, verbose=False):
 
     [if_access, of_access, fl_access] = level_access
 
-    buffer_level_access = [if_access, of_access, fl_access]
+    buffer_level_access = [if_access, 2 * of_access - 1, fl_access]
+    total_buffer_access = list(map(mul, buffer_level_access, layer_size))
     # level_cost = sum(total_buffer_access) * resource.access_cost[level]
     level_cost = 0
-    for i in range(len(buffer_level_access)):
+    for i in range(len(total_buffer_access)):
         index = resource.memory_partitions[level][i]
         if index is not None:
-            level_cost += buffer_level_access[i] * resource.access_cost[level][index]
+            level_cost += total_buffer_access[i] * resource.access_cost[level][index]
     # operand_costs = [access_cost * num_accesses for access_cost,num_accesses in zip(total_buffer_access,resource.access_cost[level]) ]
     # level_cost = sum(operand_costs)
 
     if verbose >= 3:
         print("Level ", level, " access: ", buffer_level_access)
 
-    # level_cost += get_array_level_cost(
-    #     resource, point, layer_size, level - 1, level_access, verbose
-    # )
+    level_cost += get_array_level_cost(
+        resource, point, layer_size, level - 1, level_access, verbose
+    )
 
     return level_cost
 
@@ -1276,20 +1065,23 @@ def get_level_cost(resource, point, layer, level, verbose=False):
     layer_size = get_layer_size(layer)
     mac_capacity = resource.mac_capacity
 
-    if_accesses = get_if_access(resource, point, layer, mac_capacity)
-    of_accesses = get_of_access(resource, point, layer, mac_capacity)
-    fl_accesses = get_fl_access(resource, point, layer, mac_capacity)
-
-    buffer_access = list(zip(if_accesses, of_accesses, fl_accesses))
+    level_access = [
+        get_if_access(level, point, layer, mac_capacity),
+        2 * get_of_access(level, point, layer, mac_capacity) - 1,
+        get_fl_access(level, point, layer, mac_capacity),
+    ]
 
+    buffer_access = list(map(mul, level_access, layer_size))
     # Inputs, weights, and outputs may have different costs
     # level_cost = sum(buffer_access) * resource.access_cost[level]
     level_cost = 0
-    for i in range(3):
-        memory_partition = resource.memory_partitions[level][i]
-        level_cost += (
-            buffer_access[level][i] * resource.access_cost[level][memory_partition]
-        )
+    for i in range(len(buffer_access)):
+        index = resource.memory_partitions[level][i]
+        if index is not None:
+            level_cost += buffer_access[i] * resource.access_cost[level][index]
+    # resource.memory_partitions
+    # operand_costs = [access_cost * num_accesses for access_cost,num_accesses in zip(buffer_access,resource.access_cost[level]) ]
+    # level_cost = sum(operand_costs)
 
     if verbose >= 3:
         print("Level", level, " access: ", level_access)
@@ -1387,6 +1179,7 @@ def get_cost(resource, point, layer, verbose=False):
     )
 
     access_list, array_cost = get_access(point, layer, resource)
+    layer_size = get_layer_size(layer)
 
     total_access_cost = get_total_access_cost(resource, array_cost)
     assert len(total_access_cost) == len(access_list)
@@ -1395,12 +1188,12 @@ def get_cost(resource, point, layer, verbose=False):
     for i in range(len(total_access_cost)):
         """List of total access of each buffer at level i"""
         if not isinstance(access_list[i][0], list):
-            total_cost += sum(
-                [access * total_access_cost[i][0] for access in access_list[i]]
-            )
+            buffer_access = list(map(mul, access_list[i], layer_size))
+            total_cost += sum(buffer_access) * total_access_cost[i][0]
         else:
             for j in range(len(access_list[i])):
-                total_cost += access_list[i][j] * total_access_cost[i][j]
+                buffer_access = list(map(mul, access_list[i][j], layer_size))
+                total_cost += sum(buffer_access) * total_access_cost[i][j]
 
     if verbose:
         # print("total_access_cost", total_access_cost)
@@ -1473,4 +1266,4 @@ def get_cost(resource, point, layer, verbose=False):
         # print('total cost: ', total_cost)
 
     # return total_cost
-    return total_cost, total_access_cost, access_list
+    return total_cost, total_access_cost, access_list, layer_size
-- 
GitLab