Newer
Older
def _check_level_count(data, key):
    """Raise AssertionError unless ``data[key]`` has exactly ``data["mem_levels"]`` entries.

    Raised explicitly instead of via an ``assert`` statement so the validation
    still runs under ``python -O``. The exception type and message match the
    original ``assert`` statements, so existing callers are unaffected.
    """
    if data["mem_levels"] != len(data[key]):
        raise AssertionError(
            "{} list is invalid, too many or too few elements".format(key))


def extract_arch_info(arch_file):
    """Load an architecture description from a JSON file and fill in defaults.

    Args:
        arch_file: Path to a JSON file. It must contain ``mem_levels``,
            ``capacity``, ``access_cost``, ``parallel_count`` and ``precision``.

    Returns:
        The parsed dict, modified as follows:

        * ``capacity`` is divided by ``precision / 8``. It may be a flat
          per-level list or a list of per-level lists. Presumably
          ``precision`` is bits per element and ``capacity`` is in bytes,
          giving element counts; confirm against callers.
        * Missing optional keys get defaults: ``static_cost`` (zeros per
          level), ``mac_capacity`` (0), ``parallel_mode`` (1 where
          ``parallel_count != 1``, else 0), ``array_dim`` (None),
          ``utilization_threshold`` (0.0), ``invalid_underutilized`` (True)
          and ``memory_partitions`` (three ``[0, 0, 0]`` lists; their
          meaning is not visible here — TODO confirm).

    Raises:
        AssertionError: If a per-level list does not have ``mem_levels``
            entries.
        KeyError: If a required key is missing.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(arch_file, encoding="utf-8") as json_data_file:
        data = json.load(json_data_file)

    for key in ("capacity", "access_cost", "parallel_count"):
        _check_level_count(data, key)

    num_bytes = data["precision"] / 8
    capacity = data["capacity"]
    # Guard the [0] probe: a zero-level architecture has an empty capacity list.
    if capacity and isinstance(capacity[0], list):
        data["capacity"] = [[x / num_bytes for x in level] for level in capacity]
    else:
        data["capacity"] = [x / num_bytes for x in capacity]

    if "static_cost" not in data:
        data["static_cost"] = [0] * data["mem_levels"]
    else:
        _check_level_count(data, "static_cost")

    data.setdefault("mac_capacity", 0)

    if "parallel_mode" not in data:
        # Default: any level with more than one parallel unit uses mode 1.
        data["parallel_mode"] = [
            1 if count != 1 else 0 for count in data["parallel_count"]
        ]
    else:
        _check_level_count(data, "parallel_mode")

    data.setdefault("array_dim", None)
    data.setdefault("utilization_threshold", 0.0)
    data.setdefault("invalid_underutilized", True)
    # The literal is rebuilt on every call, so callers never share the lists.
    data.setdefault("memory_partitions", [[0, 0, 0], [0, 0, 0], [0, 0, 0]])
    return data
def extract_network_info(network_file):
    """Load a layer description from a JSON file and add derived fields.

    Args:
        network_file: Path to a JSON file describing one layer.

    Returns:
        The parsed dict, modified as follows:

        * ``batch_size``, ``stride_width`` and ``stride_height`` default to 1
          when absent.
        * ``layer_info`` is a list snapshot of the dict's values, taken after
          the defaults above are applied.
        * ``layer_name`` is the file's base name without its extension.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(network_file, encoding="utf-8") as json_data_file:
        data = json.load(json_data_file)

    data.setdefault("batch_size", 1)
    data.setdefault("stride_width", 1)
    data.setdefault("stride_height", 1)

    # Snapshot the values *before* adding the derived keys. A bare
    # data.values() is a live view in Python 3: it would later contain the
    # view itself (via "layer_info") and the "layer_name" string, and it
    # cannot be serialised back to JSON.
    data["layer_info"] = list(data.values())
    data["layer_name"] = os.path.splitext(os.path.basename(network_file))[0]
    return data
def extract_schedule_info(schedule_file, num_levels):
    """Parse a schedule-hint JSON file into a per-loop, per-level table.

    Args:
        schedule_file: Path to a JSON file with a ``schedule_hint`` object and
            an optional ``partition_loops`` entry.
        num_levels: Number of memory levels; each loop gets that many slots.

    Returns:
        A dict with two keys:

        * ``schedule_hint``: maps ``le.loop_table[loop]`` to a list of
          ``num_levels`` slots. A slot stays None unless the file mentions
          that level. A mentioned level becomes
          ``[order, blocking_size, partitioning_size]``, with None for any
          field the file omits.
        * ``partition_loops``: copied from the file, or None when absent.
    """
    with open(schedule_file) as fh:
        raw = json.load(fh)

    # Position of each optional field inside a level's 3-slot entry.
    field_slots = ("order", "blocking_size", "partitioning_size")

    hints_by_loop = {}
    for loop_name, per_level in raw["schedule_hint"].items():
        level_slots = [None, ] * num_levels
        hints_by_loop[le.loop_table[loop_name]] = level_slots
        for level_name, spec in per_level.items():
            # NOTE: lstrip removes any leading 'l'/'e'/'v' characters (a
            # character set, not a prefix); fine for keys like "level0".
            level_idx = int(level_name.lstrip('level'))
            entry = [None, ] * 3
            level_slots[level_idx] = entry
            for slot, field in enumerate(field_slots):
                if field in spec:
                    entry[slot] = spec[field]

    # TODO partition at dimension
    return {
        "schedule_hint": hints_by_loop,
        "partition_loops": raw.get("partition_loops"),
    }
def extract_info(args):
    """Load the architecture, network and (optional) schedule descriptions.

    Args:
        args: An object exposing ``arch``, ``network`` and ``schedule`` path
            attributes. ``schedule`` may be falsy to skip schedule loading.

    Returns:
        A tuple ``(arch_info, network_info, schedule_info)``. ``schedule_info``
        is None when ``args.schedule`` is falsy.
    """
    arch_info = extract_arch_info(args.arch)
    network_info = extract_network_info(args.network)

    schedule_info = None
    if args.schedule:
        # The schedule table is sized by the architecture's memory levels.
        schedule_info = extract_schedule_info(args.schedule,
                                              arch_info["mem_levels"])

    return arch_info, network_info, schedule_info