diff --git a/src/interstellar/cost_model.py b/src/interstellar/cost_model.py
index f1df0e644eb62e0f629d05529558a7f2dd9e4832..e0ccec734c69cba322e260c8d7fd60010c403698 100644
--- a/src/interstellar/cost_model.py
+++ b/src/interstellar/cost_model.py
@@ -94,7 +94,89 @@ def valid_dataflow(resource, hint):
     return True
 
 
-def get_if_access(level, point, layer, mac_capacity=1):
+def get_if_access(resource, point, layer, mac_capacity=1):
+    """
+    Returns the number of accesses to the inputs for each level.
+    """
+
+    irrelevant_loops = [le.OC, le.FX, le.FY]
+
+    num_levels = resource.buffer_levels()
+    access_counts_per_level = []
+
+    for level in range(num_levels):
+        # general idea: total number of accesses = tiling at current level * block size * num_blocks
+        # block size = tiling without the irrelevant loops
+        # num_blocks = tiling at the levels above the current level
+
+        # multiply all the tiling factors at the current level
+
+        def multiply_tiling_factors():
+            # find the innermost loop among [OX, OY, IC, ON]
+            lowest_input_loop_index = min(
+                point.loop_orders[le.OX][level],
+                point.loop_orders[le.OY][level],
+                point.loop_orders[le.IC][level],
+                point.loop_orders[le.ON][level],
+                # these are partially relevant
+                point.loop_orders[le.FX][level],
+                point.loop_orders[le.FY][level],
+            )
+
+            # we can ignore OC if it is at a lower level than the innermost input loop
+            # FX, FY can't be ignored, because they are partially relevant
+            tiling = 1
+            for i in range(le.NUM):
+                if i in [le.OC]:
+                    if point.loop_orders[i][level] > lowest_input_loop_index:
+                        # if the loop is at a higher level than the innermost input loop, we need to consider it
+                        tiling *= point.loop_blockings[i][level]
+                else:
+                    tiling *= point.loop_blockings[i][level]
+            return tiling
+
+        # remove all the irrelevant loops from the tiling of the levels below
+        def calculate_block_size():
+            block_size = 1
+
+            for lower_level in range(level - 1, -1, -1):
+                for i in range(le.NUM):
+                    if i not in irrelevant_loops:
+                        if i == le.OX:
+                            block_size *= point.loop_blockings[i][lower_level] + (
+                                point.loop_blockings[le.FX][lower_level] - 1
+                            )
+                        elif i == le.OY:
+                            block_size *= point.loop_blockings[i][lower_level] + (
+                                point.loop_blockings[le.FY][lower_level] - 1
+                            )
+                        else:
+                            block_size *= point.loop_blockings[i][lower_level]
+                        block_size *= point.loop_partitionings[i][lower_level]
+
+            return block_size
+
+        def get_num_blocks():
+            # get tiling of the levels above the current level
+            num_blocks = 1
+            for i in range(level + 1, num_levels):
+                for j in range(le.NUM):
+                    num_blocks *= point.loop_blockings[j][i]
+            return num_blocks
+
+        access_counts_per_level.append(
+            multiply_tiling_factors()
+            * calculate_block_size()
+            * get_num_blocks()
+            * resource.paras[level].count
+        )
+
+    # print("Accesses at each level: ", access_counts_per_level)
+
+    return access_counts_per_level
+
+
+def get_if_access_old(level, point, layer, mac_capacity=1):
     """
     Get per element # of access of Input at current level
 
@@ -132,7 +214,70 @@ def get_if_access(level, point, layer, mac_capacity=1):
     )
 
 
-def get_of_access(level, point, layer, mac_capacity=1):
+def get_of_access(resource, point, layer, mac_capacity=1):
+    irrelevant_loops = [le.FX, le.FY, le.IC]
+
+    num_levels = resource.buffer_levels()
+    access_counts_per_level = []
+    for level in range(num_levels):
+        # general idea: total number of accesses = tiling at current level * block size * num_blocks
+        # block size = tiling without the irrelevant loops
+        # num_blocks = tiling at the levels above the current level
+
+        # multiply all the tiling factors at the current level
+
+        def multiply_tiling_factors():
+
+            lowest_relevant_loop_index = min(
+                point.loop_orders[le.OX][level],
+                point.loop_orders[le.OY][level],
+                point.loop_orders[le.OC][level],
+                point.loop_orders[le.ON][level],
+            )
+
+            # we can ignore OX,OY,ON since they are not relevant to the weight
+            tiling = 1
+            for i in range(le.NUM):
+                if i in irrelevant_loops:
+                    if point.loop_orders[i][level] > lowest_relevant_loop_index:
+                        tiling *= point.loop_blockings[i][level]
+                else:
+                    tiling *= point.loop_blockings[i][level]
+            return tiling
+
+        # remove all the irrelevant loops from the tiling of the levels below
+        def calculate_block_size():
+            block_size = 1
+
+            for lower_level in range(level - 1, -1, -1):
+                for i in range(le.NUM):
+                    if i not in irrelevant_loops:
+                        block_size *= point.loop_blockings[i][lower_level]
+                        block_size *= point.loop_partitionings[i][lower_level]
+
+            return block_size
+
+        def get_num_blocks():
+            # get tiling of the levels above the current level
+            num_blocks = 1
+            for i in range(level + 1, num_levels):
+                for j in range(le.NUM):
+                    num_blocks *= point.loop_blockings[j][i]
+            return num_blocks
+
+        access_counts_per_level.append(
+            multiply_tiling_factors()
+            * calculate_block_size()
+            * get_num_blocks()
+            * resource.paras[level].count
+        )
+
+    # print("Accesses at each level: ", access_counts_per_level)
+
+    return access_counts_per_level
+
+
+def get_of_access_old(level, point, layer, mac_capacity=1):
     """
     Get per element # of access of Output at current level
 
@@ -180,7 +325,74 @@ def get_of_access(level, point, layer, mac_capacity=1):
     return fx_acc * fy_acc * ic_acc * fx_par * fy_par * ic_par
 
 
-def get_fl_access(level, point, layer, mac_capacity=1):
+def get_fl_access(resource, point, layer, mac_capacity=1):
+    """
+    Returns the number of accesses to the inputs for each level.
+    """
+
+    irrelevant_loops = [le.OX, le.OY, le.ON]
+
+    num_levels = resource.buffer_levels()
+    access_counts_per_level = []
+    for level in range(num_levels):
+        # general idea: total number of accesses = tiling at current level * block size * num_blocks
+        # block size = tiling without the irrelevant loops
+        # num_blocks = tiling at the levels above the current level
+
+        # multiply all the tiling factors at the current level
+
+        def multiply_tiling_factors():
+
+            lowest_relevant_loop_index = min(
+                point.loop_orders[le.FX][level],
+                point.loop_orders[le.FY][level],
+                point.loop_orders[le.IC][level],
+                point.loop_orders[le.OC][level],
+            )
+
+            # we can ignore OX,OY,ON since they are not relevant to the weight
+            tiling = 1
+            for i in range(le.NUM):
+                if i in irrelevant_loops:
+                    if point.loop_orders[i][level] > lowest_relevant_loop_index:
+                        tiling *= point.loop_blockings[i][level]
+                else:
+                    tiling *= point.loop_blockings[i][level]
+            return tiling
+
+        # remove all the irrelevant loops from the tiling of the levels below
+        def calculate_block_size():
+            block_size = 1
+
+            for lower_level in range(level - 1, -1, -1):
+                for i in range(le.NUM):
+                    if i not in irrelevant_loops:
+                        block_size *= point.loop_blockings[i][lower_level]
+                        block_size *= point.loop_partitionings[i][lower_level]
+
+            return block_size
+
+        def get_num_blocks():
+            # get tiling of the levels above the current level
+            num_blocks = 1
+            for i in range(level + 1, num_levels):
+                for j in range(le.NUM):
+                    num_blocks *= point.loop_blockings[j][i]
+            return num_blocks
+
+        access_counts_per_level.append(
+            multiply_tiling_factors()
+            * calculate_block_size()
+            * get_num_blocks()
+            * resource.paras[level].count
+        )
+
+    # print("Accesses at each level: ", access_counts_per_level)
+
+    return access_counts_per_level
+
+
+def get_fl_access_old(level, point, layer, mac_capacity=1):
     """
     Get per element # of access of Weight at current level
 
@@ -647,11 +859,12 @@ def get_access(point, layer, resource):
     mac_capacity = resource.mac_capacity
 
     access_list = []
-    for level in range(num_levels):
-        if_block_access = get_if_access(level, point, layer, mac_capacity)
-        of_block_access = 2 * get_of_access(level, point, layer, mac_capacity) - 1
-        fl_block_access = get_fl_access(level, point, layer, mac_capacity)
-        access_list.append([if_block_access, of_block_access, fl_block_access])
+
+    if_accesses = get_if_access(resource, point, layer, mac_capacity)
+    of_accesses = get_of_access(resource, point, layer, mac_capacity)
+    fl_accesses = get_fl_access(resource, point, layer, mac_capacity)
+
+    access_list = list(zip(if_accesses, of_accesses, fl_accesses))
 
     # para_mode = [e.access_mode for i, e in enumerate(resource.paras) if e.access_mode != 0]
     para_mode_level = [i for i, e in enumerate(resource.paras) if e.access_mode != 0]
@@ -1007,8 +1220,7 @@ def get_array_level_cost(
 
     total_cost = 0
     for i in range(len(level_access)):
-        buffer_access = list(map(mul, level_access[i], layer_size))
-        total_cost += sum(buffer_access) * level_cost[i]
+        total_cost += level_access[i] * level_cost[i]
 
     if verbose >= 3:
         print("Level ", level, " array level access: ", level_access)
@@ -1034,23 +1246,22 @@ def get_array_and_curr_level_cost(resource, point, layer, level, verbose=False):
 
     [if_access, of_access, fl_access] = level_access
 
-    buffer_level_access = [if_access, 2 * of_access - 1, fl_access]
-    total_buffer_access = list(map(mul, buffer_level_access, layer_size))
+    buffer_level_access = [if_access, of_access, fl_access]
     # level_cost = sum(total_buffer_access) * resource.access_cost[level]
     level_cost = 0
-    for i in range(len(total_buffer_access)):
+    for i in range(len(buffer_level_access)):
         index = resource.memory_partitions[level][i]
         if index is not None:
-            level_cost += total_buffer_access[i] * resource.access_cost[level][index]
+            level_cost += buffer_level_access[i] * resource.access_cost[level][index]
     # operand_costs = [access_cost * num_accesses for access_cost,num_accesses in zip(total_buffer_access,resource.access_cost[level]) ]
     # level_cost = sum(operand_costs)
 
     if verbose >= 3:
         print("Level ", level, " access: ", buffer_level_access)
 
-    level_cost += get_array_level_cost(
-        resource, point, layer_size, level - 1, level_access, verbose
-    )
+    # level_cost += get_array_level_cost(
+    #     resource, point, layer_size, level - 1, level_access, verbose
+    # )
 
     return level_cost
 
@@ -1065,23 +1276,20 @@ def get_level_cost(resource, point, layer, level, verbose=False):
     layer_size = get_layer_size(layer)
     mac_capacity = resource.mac_capacity
 
-    level_access = [
-        get_if_access(level, point, layer, mac_capacity),
-        2 * get_of_access(level, point, layer, mac_capacity) - 1,
-        get_fl_access(level, point, layer, mac_capacity),
-    ]
+    if_accesses = get_if_access(resource, point, layer, mac_capacity)
+    of_accesses = get_of_access(resource, point, layer, mac_capacity)
+    fl_accesses = get_fl_access(resource, point, layer, mac_capacity)
+
+    buffer_access = list(zip(if_accesses, of_accesses, fl_accesses))
 
-    buffer_access = list(map(mul, level_access, layer_size))
     # Inputs, weights, and outputs may have different costs
     # level_cost = sum(buffer_access) * resource.access_cost[level]
     level_cost = 0
-    for i in range(len(buffer_access)):
-        index = resource.memory_partitions[level][i]
-        if index is not None:
-            level_cost += buffer_access[i] * resource.access_cost[level][index]
-    # resouce.memory_partitions
-    # operand_costs = [access_cost * num_accesses for access_cost,num_accesses in zip(buffer_access,resource.access_cost[level]) ]
-    # level_cost = sum(operand_costs)
+    for i in range(3):
+        memory_partition = resource.memory_partitions[level][i]
+        level_cost += (
+            buffer_access[level][i] * resource.access_cost[level][memory_partition]
+        )
 
     if verbose >= 3:
         print("Level", level, " access: ", level_access)
@@ -1179,7 +1387,6 @@ def get_cost(resource, point, layer, verbose=False):
     )
 
     access_list, array_cost = get_access(point, layer, resource)
-    layer_size = get_layer_size(layer)
 
     total_access_cost = get_total_access_cost(resource, array_cost)
     assert len(total_access_cost) == len(access_list)
@@ -1188,12 +1395,12 @@ def get_cost(resource, point, layer, verbose=False):
     for i in range(len(total_access_cost)):
         """List of total access of each buffer at level i"""
         if not isinstance(access_list[i][0], list):
-            buffer_access = list(map(mul, access_list[i], layer_size))
-            total_cost += sum(buffer_access) * total_access_cost[i][0]
+            total_cost += sum(
+                [access * total_access_cost[i][0] for access in access_list[i]]
+            )
         else:
             for j in range(len(access_list[i])):
-                buffer_access = list(map(mul, access_list[i][j], layer_size))
-                total_cost += sum(buffer_access) * total_access_cost[i][j]
+                total_cost += access_list[i][j] * total_access_cost[i][j]
 
     if verbose:
         # print("total_access_cost", total_access_cost)
@@ -1266,4 +1473,4 @@ def get_cost(resource, point, layer, verbose=False):
         # print('total cost: ', total_cost)
 
     # return total_cost
-    return total_cost, total_access_cost, access_list, layer_size
+    return total_cost, total_access_cost, access_list