Skip to content

Commit

Permalink
add CME class typehints
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinGeens committed Dec 23, 2024
1 parent aa18ccd commit 1cf2de8
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions zigzag/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
from functools import lru_cache
from math import ceil
from typing import TypedDict
from typing import TYPE_CHECKING, Any, TypedDict

import numpy as np

Expand All @@ -20,6 +20,10 @@
from zigzag.utils import json_repr_handler, pickle_deepcopy
from zigzag.workload.layer_node import LayerNode

if TYPE_CHECKING:
from zigzag.workload.layer_attributes import MemoryOperandLinks


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -294,6 +298,24 @@ class CostModelEvaluation(CostModelEvaluationABC):
After initialization, the cost model evaluation is run.
"""

accelerator: Accelerator
layer: LayerNode
spatial_mapping: SpatialMappingInternal
temporal_mapping: TemporalMapping
access_same_data_considered_as_no_access: bool
cycles_per_op: float
mem_level_list: list[MemoryLevel]
mem_hierarchy_dict: dict[MemoryOperand, list[MemoryLevel]]
mem_size_dict: dict[MemoryOperand, list[int]]
mem_r_bw_dict: tuple[dict[MemoryOperand, list[int]], dict[MemoryOperand, list[int]]]
mem_r_bw_min_dict: tuple[dict[MemoryOperand, list[int]], dict[MemoryOperand, list[int]]]
mem_sharing_tuple: tuple[tuple[tuple[MemoryOperand, int], ...], ...]
memory_operand_links: "MemoryOperandLinks"
spatial_mapping_dict_int: Any
mapping: Mapping
mapping_int: Mapping
active_mem_level: dict[LayerOperand, int]

def __init__(
self,
*,
Expand All @@ -317,7 +339,6 @@ def __init__(
self.spatial_mapping_int = spatial_mapping_int # the original spatial mapping without decimal
self.temporal_mapping = temporal_mapping
self.access_same_data_considered_as_no_access = access_same_data_considered_as_no_access

self.mem_level_list = accelerator.memory_hierarchy.mem_level_list
self.mem_hierarchy_dict = accelerator.mem_hierarchy_dict
self.mem_size_dict = accelerator.mem_size_dict
Expand Down

0 comments on commit 1cf2de8

Please sign in to comment.