diff --git a/src/interstellar/cost_model.py b/src/interstellar/cost_model.py index f1df0e644eb62e0f629d05529558a7f2dd9e4832..e0ccec734c69cba322e260c8d7fd60010c403698 100644 --- a/src/interstellar/cost_model.py +++ b/src/interstellar/cost_model.py @@ -94,7 +94,89 @@ def valid_dataflow(resource, hint): return True -def get_if_access(level, point, layer, mac_capacity=1): +def get_if_access(resource, point, layer, mac_capacity=1): + """ + Returns the number of accesses to the inputs for each level. + """ + + irrelevant_loops = [le.OC, le.FX, le.FY] + + num_levels = resource.buffer_levels() + access_counts_per_level = [] + + for level in range(num_levels): + # general idea: total number of accesses = tiling at current level * block size * num_blocks + # block size = tiling without the irrelevant loops + # num_blocks = tiling at the levels above the current level + + # multiply all the tiling factors at the current level + + def multiply_tiling_factors(): + # find the innermost loop among [OX, OY, IC, ON] + lowest_input_loop_index = min( + point.loop_orders[le.OX][level], + point.loop_orders[le.OY][level], + point.loop_orders[le.IC][level], + point.loop_orders[le.ON][level], + # these are partially relevant + point.loop_orders[le.FX][level], + point.loop_orders[le.FY][level], + ) + + # we can ignore OC if it is at a lower level than the innermost input loop + # FX, FY can't be ignored, because they are partially relevant + tiling = 1 + for i in range(le.NUM): + if i in [le.OC]: + if point.loop_orders[i][level] > lowest_input_loop_index: + # if the loop is at a higher level than the innermost input loop, we need to consider it + tiling *= point.loop_blockings[i][level] + else: + tiling *= point.loop_blockings[i][level] + return tiling + + # remove all the irrelevant loops from the tiling of the levels below + def calculate_block_size(): + block_size = 1 + + for lower_level in range(level - 1, -1, -1): + for i in range(le.NUM): + if i not in irrelevant_loops: + if i == le.OX: + block_size *= point.loop_blockings[i][lower_level] + ( + point.loop_blockings[le.FX][lower_level] - 1 + ) + elif i == le.OY: + block_size *= point.loop_blockings[i][lower_level] + ( + point.loop_blockings[le.FY][lower_level] - 1 + ) + else: + block_size *= point.loop_blockings[i][lower_level] + block_size *= point.loop_partitionings[i][lower_level] + + return block_size + + def get_num_blocks(): + # get tiling of the levels above the current level + num_blocks = 1 + for i in range(level + 1, num_levels): + for j in range(le.NUM): + num_blocks *= point.loop_blockings[j][i] + return num_blocks + + access_counts_per_level.append( + multiply_tiling_factors() + * calculate_block_size() + * get_num_blocks() + * resource.paras[level].count + ) + + # print("Accesses at each level: ", access_counts_per_level) + + return access_counts_per_level + + +def get_if_access_old(level, point, layer, mac_capacity=1): """ Get per element # of access of Input at current level @@ -132,7 +214,70 @@ def get_if_access(level, point, layer, mac_capacity=1): ) -def get_of_access(level, point, layer, mac_capacity=1): +def get_of_access(resource, point, layer, mac_capacity=1): + irrelevant_loops = [le.FX, le.FY, le.IC] + + num_levels = resource.buffer_levels() + access_counts_per_level = [] + for level in range(num_levels): + # general idea: total number of accesses = tiling at current level * block size * num_blocks + # block size = tiling without the irrelevant loops + # num_blocks = tiling at the levels above the current level + + # multiply all the tiling factors at the current level + + def multiply_tiling_factors(): + + lowest_relevant_loop_index = min( + point.loop_orders[le.OX][level], + point.loop_orders[le.OY][level], + point.loop_orders[le.OC][level], + point.loop_orders[le.ON][level], + ) + + # we can ignore OX,OY,ON since they are not relevant to the weight + tiling = 1 + for i in range(le.NUM): + if i in irrelevant_loops: + if point.loop_orders[i][level] > lowest_relevant_loop_index: + tiling *= point.loop_blockings[i][level] + else: + tiling *= point.loop_blockings[i][level] + return tiling + + # remove all the irrelevant loops from the tiling of the levels below + def calculate_block_size(): + block_size = 1 + + for lower_level in range(level - 1, -1, -1): + for i in range(le.NUM): + if i not in irrelevant_loops: + block_size *= point.loop_blockings[i][lower_level] + block_size *= point.loop_partitionings[i][lower_level] + + return block_size + + def get_num_blocks(): + # get tiling of the levels above the current level + num_blocks = 1 + for i in range(level + 1, num_levels): + for j in range(le.NUM): + num_blocks *= point.loop_blockings[j][i] + return num_blocks + + access_counts_per_level.append( + multiply_tiling_factors() + * calculate_block_size() + * get_num_blocks() + * resource.paras[level].count + ) + + # print("Accesses at each level: ", access_counts_per_level) + + return access_counts_per_level + + +def get_of_access_old(level, point, layer, mac_capacity=1): """ Get per element # of access of Output at current level @@ -180,7 +325,74 @@ def get_of_access(level, point, layer, mac_capacity=1): return fx_acc * fy_acc * ic_acc * fx_par * fy_par * ic_par -def get_fl_access(level, point, layer, mac_capacity=1): +def get_fl_access(resource, point, layer, mac_capacity=1): + """ + Returns the number of accesses to the inputs for each level. + """ + + irrelevant_loops = [le.OX, le.OY, le.ON] + + num_levels = resource.buffer_levels() + access_counts_per_level = [] + for level in range(num_levels): + # general idea: total number of accesses = tiling at current level * block size * num_blocks + # block size = tiling without the irrelevant loops + # num_blocks = tiling at the levels above the current level + + # multiply all the tiling factors at the current level + + def multiply_tiling_factors(): + + lowest_relevant_loop_index = min( + point.loop_orders[le.FX][level], + point.loop_orders[le.FY][level], + point.loop_orders[le.IC][level], + point.loop_orders[le.OC][level], + ) + + # we can ignore OX,OY,ON since they are not relevant to the weight + tiling = 1 + for i in range(le.NUM): + if i in irrelevant_loops: + if point.loop_orders[i][level] > lowest_relevant_loop_index: + tiling *= point.loop_blockings[i][level] + else: + tiling *= point.loop_blockings[i][level] + return tiling + + # remove all the irrelevant loops from the tiling of the levels below + def calculate_block_size(): + block_size = 1 + + for lower_level in range(level - 1, -1, -1): + for i in range(le.NUM): + if i not in irrelevant_loops: + block_size *= point.loop_blockings[i][lower_level] + block_size *= point.loop_partitionings[i][lower_level] + + return block_size + + def get_num_blocks(): + # get tiling of the levels above the current level + num_blocks = 1 + for i in range(level + 1, num_levels): + for j in range(le.NUM): + num_blocks *= point.loop_blockings[j][i] + return num_blocks + + access_counts_per_level.append( + multiply_tiling_factors() + * calculate_block_size() + * get_num_blocks() + * resource.paras[level].count + ) + + # print("Accesses at each level: ", access_counts_per_level) + + return access_counts_per_level + + +def get_fl_access_old(level, point, layer, mac_capacity=1): """ Get per element # of access of Weight at current level @@ -647,11 +859,12 @@ def get_access(point, layer, resource): mac_capacity = resource.mac_capacity access_list = [] - for level in range(num_levels): - if_block_access = get_if_access(level, point, layer, mac_capacity) - of_block_access = 2 * get_of_access(level, point, layer, mac_capacity) - 1 - fl_block_access = get_fl_access(level, point, layer, mac_capacity) - access_list.append([if_block_access, of_block_access, fl_block_access]) + + if_accesses = get_if_access(resource, point, layer, mac_capacity) + of_accesses = get_of_access(resource, point, layer, mac_capacity) + fl_accesses = get_fl_access(resource, point, layer, mac_capacity) + + access_list = list(zip(if_accesses, of_accesses, fl_accesses)) # para_mode = [e.access_mode for i, e in enumerate(resource.paras) if e.access_mode != 0] para_mode_level = [i for i, e in enumerate(resource.paras) if e.access_mode != 0] @@ -1007,8 +1220,7 @@ def get_array_level_cost( total_cost = 0 for i in range(len(level_access)): - buffer_access = list(map(mul, level_access[i], layer_size)) - total_cost += sum(buffer_access) * level_cost[i] + total_cost += level_access[i] * level_cost[i] if verbose >= 3: print("Level ", level, " array level access: ", level_access) @@ -1034,23 +1246,22 @@ def get_array_and_curr_level_cost(resource, point, layer, level, verbose=False): [if_access, of_access, fl_access] = level_access - buffer_level_access = [if_access, 2 * of_access - 1, fl_access] - total_buffer_access = list(map(mul, buffer_level_access, layer_size)) + buffer_level_access = [if_access, of_access, fl_access] # level_cost = sum(total_buffer_access) * resource.access_cost[level] level_cost = 0 - for i in range(len(total_buffer_access)): + for i in range(len(buffer_level_access)): index = resource.memory_partitions[level][i] if index is not None: - level_cost += total_buffer_access[i] * resource.access_cost[level][index] + level_cost += buffer_level_access[i] * resource.access_cost[level][index] # operand_costs = [access_cost * num_accesses for access_cost,num_accesses in zip(total_buffer_access,resource.access_cost[level]) ] # level_cost = sum(operand_costs) if verbose >= 3: print("Level ", level, " access: ", buffer_level_access) - level_cost += get_array_level_cost( - resource, point, layer_size, level - 1, level_access, verbose - ) + # level_cost += get_array_level_cost( + # resource, point, layer_size, level - 1, level_access, verbose + # ) return level_cost @@ -1065,23 +1276,20 @@ def get_level_cost(resource, point, layer, level, verbose=False): layer_size = get_layer_size(layer) mac_capacity = resource.mac_capacity - level_access = [ - get_if_access(level, point, layer, mac_capacity), - 2 * get_of_access(level, point, layer, mac_capacity) - 1, - get_fl_access(level, point, layer, mac_capacity), - ] + if_accesses = get_if_access(resource, point, layer, mac_capacity) + of_accesses = get_of_access(resource, point, layer, mac_capacity) + fl_accesses = get_fl_access(resource, point, layer, mac_capacity) + + buffer_access = list(zip(if_accesses, of_accesses, fl_accesses)) - buffer_access = list(map(mul, level_access, layer_size)) # Inputs, weights, and outputs may have different costs # level_cost = sum(buffer_access) * resource.access_cost[level] level_cost = 0 - for i in range(len(buffer_access)): - index = resource.memory_partitions[level][i] - if index is not None: - level_cost += buffer_access[i] * resource.access_cost[level][index] - # resouce.memory_partitions - # operand_costs = [access_cost * num_accesses for access_cost,num_accesses in zip(buffer_access,resource.access_cost[level]) ] - # level_cost = sum(operand_costs) + for i in range(3): + memory_partition = resource.memory_partitions[level][i] + level_cost += ( + buffer_access[level][i] * resource.access_cost[level][memory_partition] + ) if verbose >= 3: print("Level", level, " access: ", level_access) @@ -1179,7 +1387,6 @@ def get_cost(resource, point, layer, verbose=False): ) access_list, array_cost = get_access(point, layer, resource) - layer_size = get_layer_size(layer) total_access_cost = get_total_access_cost(resource, array_cost) assert len(total_access_cost) == len(access_list) @@ -1188,12 +1395,12 @@ def get_cost(resource, point, layer, verbose=False): for i in range(len(total_access_cost)): """List of total access of each buffer at level i""" if not isinstance(access_list[i][0], list): - buffer_access = list(map(mul, access_list[i], layer_size)) - total_cost += sum(buffer_access) * total_access_cost[i][0] + total_cost += sum( + [access * total_access_cost[i][0] for access in access_list[i]] + ) else: for j in range(len(access_list[i])): - buffer_access = list(map(mul, access_list[i][j], layer_size)) - total_cost += sum(buffer_access) * total_access_cost[i][j] + total_cost += access_list[i][j] * total_access_cost[i][j] if verbose: # print("total_access_cost", total_access_cost) @@ -1266,4 +1473,4 @@ def get_cost(resource, point, layer, verbose=False): # print('total cost: ', total_cost) # return total_cost - return total_cost, total_access_cost, access_list, layer_size + return total_cost, total_access_cost, access_list