Skip to content
Snippets Groups Projects
Commit f5198e78 authored by sgauthamr2001's avatar sgauthamr2001
Browse files

Revert "Use schedule hint to constrain loop blocking and loop order"

This reverts commit ee11d9bd.
parent 0b7704d6
No related branches found
No related tags found
No related merge requests found
...@@ -243,16 +243,7 @@ def loop_tile_with_hint(tile_permutations, loop_extent, num_level, loop_hint): ...@@ -243,16 +243,7 @@ def loop_tile_with_hint(tile_permutations, loop_extent, num_level, loop_hint):
def loop_tile(loop_extent, num_level, loop_hint=None): def loop_tile(loop_extent, num_level, loop_hint=None):
tile_permutations = [] tile_permutations = []
if not loop_hint:
# check if the hint specifies the blocking at any level
is_valid_hint = False
if loop_hint is not None:
for level in range(num_level):
if loop_hint[level] is not None:
is_valid_hint = True
break
if not is_valid_hint:
recursive_tile(tile_permutations, [], loop_extent, 0, num_level) recursive_tile(tile_permutations, [], loop_extent, 0, num_level)
else: else:
loop_tile_with_hint(tile_permutations, loop_extent, num_level, loop_hint) loop_tile_with_hint(tile_permutations, loop_extent, num_level, loop_hint)
...@@ -260,30 +251,12 @@ def loop_tile(loop_extent, num_level, loop_hint=None): ...@@ -260,30 +251,12 @@ def loop_tile(loop_extent, num_level, loop_hint=None):
return tile_permutations return tile_permutations
def opt_valid_blocking(blocking_cache, resource, layer, blocking, schedule=None): def opt_valid_blocking(blocking_cache, resource, layer, blocking):
"""
Checks if a given blocking configuration is valid for a specific layer and resource setup.
"""
num_levels = resource.buffer_levels() num_levels = resource.buffer_levels()
blocking_tuple = list(zip(*blocking)) blocking_tuple = list(zip(*blocking))
dummy_partitioning = [(1,) * num_levels] * le.NUM dummy_partitioning = [(1,) * num_levels] * le.NUM
dummy_mapping_point = MappingPoint(None, list(blocking), dummy_partitioning) dummy_mapping_point = MappingPoint(None, list(blocking), dummy_partitioning)
"""
Check if blocking fits the blocking constraints in the schedule hint
"""
if schedule is not None and schedule.schedule_hint is not None:
for loop_index in range(le.NUM):
if loop_index in schedule.schedule_hint:
for level in range(num_levels):
if schedule.schedule_hint[loop_index][level] is not None:
if (
schedule.schedule_hint[loop_index][level][1] is not None
and blocking_tuple[level][loop_index]
!= schedule.schedule_hint[loop_index][level][1]
):
return False
""" """
Use cache to compute valid of first level Use cache to compute valid of first level
""" """
...@@ -331,7 +304,7 @@ def blocking_generator_function(resource, layer, schedule=None, verbose=False): ...@@ -331,7 +304,7 @@ def blocking_generator_function(resource, layer, schedule=None, verbose=False):
for tile in itertools.product(*all_tile_permutations): for tile in itertools.product(*all_tile_permutations):
# TODO here the generated is a list of lists, not a list of tuples # TODO here the generated is a list of lists, not a list of tuples
# if cost_model.valid_blocking_size(resource, dummy_mapping_point, layer): # if cost_model.valid_blocking_size(resource, dummy_mapping_point, layer):
if opt_valid_blocking(blocking_cache, resource, layer, tile, schedule): if opt_valid_blocking(blocking_cache, resource, layer, tile):
yield list(tile) yield list(tile)
...@@ -774,7 +747,7 @@ def blocking_partitioning_generator_function(resource, layer, schedule, verbose= ...@@ -774,7 +747,7 @@ def blocking_partitioning_generator_function(resource, layer, schedule, verbose=
print("") print("")
def opt_get_best_loop_order(resource, layer, point, schedule=None, verbose=False): def opt_get_best_loop_order(resource, layer, point, verbose=False):
""" """
[HW template right now: systolic array] [HW template right now: systolic array]
...@@ -819,43 +792,6 @@ def opt_get_best_loop_order(resource, layer, point, schedule=None, verbose=False ...@@ -819,43 +792,6 @@ def opt_get_best_loop_order(resource, layer, point, schedule=None, verbose=False
mapping_point = MappingPoint( mapping_point = MappingPoint(
list(zip(*dummy_loop_order)), blocking, partitioning, para_dim list(zip(*dummy_loop_order)), blocking, partitioning, para_dim
) )
# check mapping point fits the schedule hint constraint
valid_mapping_point = True
if schedule is not None and schedule.schedule_hint is not None:
for loop_index in range(le.NUM):
# if there's a blocking of 1, we can ignore the loop order constraint
if blocking[loop_index][level] == 1:
continue
if loop_index in schedule.schedule_hint:
if (
schedule.schedule_hint[loop_index][level] is not None
and schedule.schedule_hint[loop_index][level][0] is not None
):
loop_index_constraint = schedule.schedule_hint[loop_index][
level
][0]
# for negative value, it means we should count from the outermost loop
# e.g. -1 means the outermost loop
if loop_index_constraint < 0:
# find the max loop index (that is not 6) in the current level order
max_loop_index = max(
[i if i != 6 else -1 for i in curr_level_order]
)
loop_index_constraint = (
max_loop_index + 1 + loop_index_constraint
)
# if the loop index is not the same as the constraint, then it's not a valid mapping point
if curr_level_order[loop_index] != loop_index_constraint:
valid_mapping_point = False
break
if not valid_mapping_point:
continue
if ( if (
level <= 0 level <= 0
or resource.paras[level - 1].count <= 1 or resource.paras[level - 1].count <= 1
...@@ -933,7 +869,7 @@ def opt_mapping_point_generator_function(resource, layer, schedule=None, verbose ...@@ -933,7 +869,7 @@ def opt_mapping_point_generator_function(resource, layer, schedule=None, verbose
dummy_mapping_point = MappingPoint(None, blocking, partitioning, para_dim) dummy_mapping_point = MappingPoint(None, blocking, partitioning, para_dim)
# print "blocking_partitioning: ", blocking_partitioning # print "blocking_partitioning: ", blocking_partitioning
cost, loop_order = opt_get_best_loop_order( cost, loop_order = opt_get_best_loop_order(
resource, layer, dummy_mapping_point, schedule, verbose resource, layer, dummy_mapping_point, verbose
) )
if cost < smallest_cost: if cost < smallest_cost:
smallest_cost = cost smallest_cost = cost
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment