diff --git a/src/interstellar/mapping_point_generator.py b/src/interstellar/mapping_point_generator.py index d362969868f7befb3077c1541bef1392307f1468..3643b1e4414c64d4e89364c0d9b1c1a3ec05f4fa 100644 --- a/src/interstellar/mapping_point_generator.py +++ b/src/interstellar/mapping_point_generator.py @@ -243,16 +243,7 @@ def loop_tile_with_hint(tile_permutations, loop_extent, num_level, loop_hint): def loop_tile(loop_extent, num_level, loop_hint=None): tile_permutations = [] - - # check if the hint specifies the blocking at any level - is_valid_hint = False - if loop_hint is not None: - for level in range(num_level): - if loop_hint[level] is not None: - is_valid_hint = True - break - - if not is_valid_hint: + if not loop_hint: recursive_tile(tile_permutations, [], loop_extent, 0, num_level) else: loop_tile_with_hint(tile_permutations, loop_extent, num_level, loop_hint) @@ -260,30 +251,12 @@ def loop_tile(loop_extent, num_level, loop_hint=None): return tile_permutations -def opt_valid_blocking(blocking_cache, resource, layer, blocking, schedule=None): - """ - Checks if a given blocking configuration is valid for a specific layer and resource setup. - """ +def opt_valid_blocking(blocking_cache, resource, layer, blocking): num_levels = resource.buffer_levels() blocking_tuple = list(zip(*blocking)) dummy_partitioning = [(1,) * num_levels] * le.NUM dummy_mapping_point = MappingPoint(None, list(blocking), dummy_partitioning) - """ - Check if blocking fits the blocking constraints in the schedule hint - """ - if schedule is not None and schedule.schedule_hint is not None: - for loop_index in range(le.NUM): - if loop_index in schedule.schedule_hint: - for level in range(num_levels): - if schedule.schedule_hint[loop_index][level] is not None: - if ( - schedule.schedule_hint[loop_index][level][1] is not None - and blocking_tuple[level][loop_index] - != schedule.schedule_hint[loop_index][level][1] - ): - return False - """ Use cache to compute valid of first level """ @@ -331,7 +304,7 @@ def blocking_generator_function(resource, layer, schedule=None, verbose=False): for tile in itertools.product(*all_tile_permutations): # TODO here the generated is a list of lists, not a list of tuples # if cost_model.valid_blocking_size(resource, dummy_mapping_point, layer): - if opt_valid_blocking(blocking_cache, resource, layer, tile, schedule): + if opt_valid_blocking(blocking_cache, resource, layer, tile): yield list(tile) @@ -774,7 +747,7 @@ def blocking_partitioning_generator_function(resource, layer, schedule, verbose= print("") -def opt_get_best_loop_order(resource, layer, point, schedule=None, verbose=False): +def opt_get_best_loop_order(resource, layer, point, verbose=False): """ [HW template right now: systolic array] @@ -819,43 +792,6 @@ def opt_get_best_loop_order(resource, layer, point, schedule=None, verbose=False mapping_point = MappingPoint( list(zip(*dummy_loop_order)), blocking, partitioning, para_dim ) - - # check mapping point fits the schedule hint constraint - valid_mapping_point = True - if schedule is not None and schedule.schedule_hint is not None: - for loop_index in range(le.NUM): - # if there's a blocking of 1, we can ignore the loop order constraint - if blocking[loop_index][level] == 1: - continue - - if loop_index in schedule.schedule_hint: - if ( - schedule.schedule_hint[loop_index][level] is not None - and schedule.schedule_hint[loop_index][level][0] is not None - ): - loop_index_constraint = schedule.schedule_hint[loop_index][ - level - ][0] - - # for negative value, it means we should count from the outermost loop - # e.g. -1 means the outermost loop - if loop_index_constraint < 0: - # find the max loop index (that is not 6) in the current level order - max_loop_index = max( - [i if i != 6 else -1 for i in curr_level_order] - ) - loop_index_constraint = ( - max_loop_index + 1 + loop_index_constraint - ) - - # if the loop index is not the same as the constraint, then it's not a valid mapping point - if curr_level_order[loop_index] != loop_index_constraint: - valid_mapping_point = False - break - - if not valid_mapping_point: - continue - if ( level <= 0 or resource.paras[level - 1].count <= 1 @@ -933,7 +869,7 @@ def opt_mapping_point_generator_function(resource, layer, schedule=None, verbose dummy_mapping_point = MappingPoint(None, blocking, partitioning, para_dim) # print "blocking_partitioning: ", blocking_partitioning cost, loop_order = opt_get_best_loop_order( - resource, layer, dummy_mapping_point, schedule, verbose + resource, layer, dummy_mapping_point, verbose ) if cost < smallest_cost: smallest_cost = cost