From 470a2482e9bd6dea90073482ba8fa058faeb4f25 Mon Sep 17 00:00:00 2001
From: sgauthamr2001 <sgauthamr2001@gmail.com>
Date: Sun, 5 Jan 2025 00:02:49 -0800
Subject: [PATCH] Correctly use stride when calculating number of input
 accesses

---
 src/interstellar/cost_model.py | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/interstellar/cost_model.py b/src/interstellar/cost_model.py
index e0ccec7..7e2904d 100644
--- a/src/interstellar/cost_model.py
+++ b/src/interstellar/cost_model.py
@@ -143,13 +143,28 @@ def get_if_access(resource, point, layer, mac_capacity=1):
                 for i in range(le.NUM):
                     if i not in irrelevant_loops:
                         if i == le.OX:
-                            block_size *= point.loop_blockings[i][lower_level] + (
-                                point.loop_blockings[le.FX][lower_level] - 1
-                            )
+                            stride = layer.wstd
+                            # stride should be ignored for L0 level or if FX/FY = 1
+                            if (
+                                lower_level == 0
+                                or point.loop_blockings[le.FX][lower_level] == 1
+                            ):
+                                stride = 1
+
+                            block_size *= point.loop_blockings[i][
+                                lower_level
+                            ] * stride + (point.loop_blockings[le.FX][lower_level] - 1)
                         elif i == le.OY:
-                            block_size *= point.loop_blockings[i][lower_level] + (
-                                point.loop_blockings[le.FY][lower_level] - 1
-                            )
+                            stride = layer.hstd
+                            # stride should be ignored for L0 level or if FX/FY = 1
+                            if (
+                                lower_level == 0
+                                or point.loop_blockings[le.FY][lower_level] == 1
+                            ):
+                                stride = 1
+                            block_size *= point.loop_blockings[i][
+                                lower_level
+                            ] * stride + (point.loop_blockings[le.FY][lower_level] - 1)
                         else:
                             block_size *= point.loop_blockings[i][lower_level]
                         block_size *= point.loop_partitionings[i][lower_level]
-- 
GitLab