Skip to content

Commit

Permalink
added global config + restructured the project
Browse files Browse the repository at this point in the history
  • Loading branch information
djamelrassem committed Mar 31, 2023
1 parent afb76f2 commit c1a0ecc
Show file tree
Hide file tree
Showing 17 changed files with 449 additions and 316 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ To use this project, you'll need to do the following:

1. Clone the repository to your local machine.
2. Make sure you have Anaconda installed.
3. Update the paths in the `env_api/utils/config/config.yaml` file to match your preferences.
3. Update the paths in the `config/config.yaml` file to match your preferences.

## Usage

Expand Down
File renamed without changes.
95 changes: 95 additions & 0 deletions config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal
import yaml, os


@dataclass
class TiramisuConfig:
tiramisu_path: str = ""
env_type: Literal["model", "cpu"] = "model"
tags_model_weights: str = ""


@dataclass
class DatasetConfig:
path: str = ""
offline: str = ""
save_path: str = ""
is_benchmark: bool = False
benchmark_cpp_files: str = ""
benchmark_path: str = ""


@dataclass
class Ray:
results: str = ""
restore_checkpoint: str = ""


@dataclass
class Experiment:
name: str = "test"
checkpoint_frequency: int = 10
checkpoint_num_to_keep: int = 10
training_iteration: int = 500
timesteps_total: int = 1000000
episode_reward_mean: float = 2
legality_speedup: float = 1.0

@dataclass
class PolicyNetwork:
vf_share_layers: bool = False
policy_hidden_layers : List[int] = field(
default_factory=lambda: [])
vf_hidden_layers : List[int] = field(
default_factory=lambda: [])
dropout_rate: float = 0.2
lr: float = 0.001


@dataclass
class AutoSchedulerConfig:

tiramisu: TiramisuConfig
dataset: DatasetConfig
ray: Ray
experiment:Experiment
policy_network:PolicyNetwork

def __post_init__(self):
if isinstance(self.tiramisu, dict):
self.tiramisu = TiramisuConfig(**self.tiramisu)
if isinstance(self.dataset, dict):
self.dataset = DatasetConfig(**self.dataset)
if isinstance(self.ray, dict):
self.ray = Ray(**self.ray)
if isinstance(self.experiment, dict):
self.experiment = Experiment(**self.experiment)
if isinstance(self.policy_network, dict):
self.policy_network = PolicyNetwork(**self.policy_network)


def read_yaml_file(path):
with open(path) as yaml_file:
return yaml_file.read()


def parse_yaml_file(yaml_string: str) -> Dict[Any, Any]:
return yaml.safe_load(yaml_string)


def dict_to_config(parsed_yaml: Dict[Any, Any]) -> AutoSchedulerConfig:
tiramisu = TiramisuConfig(**parsed_yaml["tiramisu"])
dataset = DatasetConfig(**parsed_yaml["dataset"])
ray = Ray(**parsed_yaml["ray"])
experiment = Experiment(**parsed_yaml["experiment"])
policy_network = PolicyNetwork(**parsed_yaml["policy_network"])
return AutoSchedulerConfig(tiramisu, dataset, ray,experiment,policy_network)


class Config(object):
config = None
@classmethod
def init(self, config_yaml="./config/config.yaml"):
parsed_yaml_dict = parse_yaml_file(read_yaml_file(config_yaml))
Config.config = dict_to_config(parsed_yaml_dict)
42 changes: 42 additions & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
tiramisu:
tiramisu_path: "/scratch/dl5133/Env/tiramisu"
env_type: "model"
tags_model_weights: "/scratch/dl5133/Dev/RL-Agent/tiramisu-env/env_api/scheduler/models/model_release_version.pt"

dataset:
path: '/scratch/dl5133/Dev/RL-Agent/rl_autoscheduler/utils/program_generator/Dataset_multi/'
offline : '/scratch/dl5133/Dev/RL-Agent/tiramisu-env/datasets/merged_valid_programs.pkl'
save_path: '/scratch/dl5133/Dev/RL-Agent/tiramisu-env/datasets'
# When doing evaluation on the benchmark set the value to True
is_benchmark: False
benchmark_cpp_files: '/scratch/dl5133/Dev/RL-Agent/tiramisu-env/benchmark/'
benchmark_path: '/scratch/dl5133/Dev/RL-Agent/tiramisu-env/datasets/benchmark_P0.pkl'

ray:
results: "/scratch/dl5133/Dev/RL-Agent/tiramisu-env/ray_results"
restore_checkpoint: "/scratch/dl5133/Dev/RL-Agent/tiramisu-env/ray_results/All-actions-pseudo-beam-search-small-networks/PPO_TiramisuRlEnv_1133f_00000_0_2023-03-30_04-59-06/checkpoint_000160"

experiment:
name: "test"
checkpoint_frequency: 10
checkpoint_num_to_keep: 10
# The following 3 values are the values to stop the experiment if any of them is reached
training_iteration: 500
timesteps_total: 1000000
episode_reward_mean: 2
# Use this value to punish or tolerate illegal actions from being taken
legality_speedup: 0.9

policy_network:
# Set this to True if you want to use shared weights between policy and value function
vf_share_layers: False
policy_hidden_layers:
- 2048
- 512
- 64
# If vf_share_layers is true then, these values won't be taken for the value network
vf_hidden_layers:
- 512
- 64
dropout_rate: 0.2
lr: 0.001
30 changes: 19 additions & 11 deletions env_api/core/models/tiramisu_program.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import re
from pathlib import Path

from env_api.data.data_service import DataSetService
from config.config import Config


class TiramisuProgram():
Expand All @@ -28,24 +27,33 @@ def from_dict(cls, name: str, data: dict):
tiramisu_prog.schedules = data["schedules"]
# After taking the neccessary fields return the instance
return tiramisu_prog

def load_code_lines(self):
'''
This function loads the file code , it is necessary to generate legality check code and annotations
'''
if (self.name):
self.file_path = DataSetService.get_filepath(func_name=self.name)
file_path = self.file_path
# if self.name is None the program doesn't exist in the offline dataset but built from compiling
# if self.name has a value than it is fetched from the dataset, we need the full path to read
# the lines of the real function to execute legality code
func_name = self.name
file_name = func_name + "_generator.cpp"
file_path = (Config.config.dataset.benchmark_cpp_files
if Config.config.dataset.is_benchmark else Config.
config.dataset.path) + func_name + "/" + file_name
self.file_path = file_path
else:
file_path = self.file_path

with open(file_path, 'r') as f:
self.original_str = f.read()
self.func_folder = ('/'.join(Path(file_path).parts[:-1]) if
len(Path(file_path).parts) > 1 else '.') + '/'
self.func_folder = ('/'.join(Path(file_path).parts[:-1])
if len(Path(file_path).parts) > 1 else '.') + '/'
self.body = re.findall(r'(tiramisu::init(?s:.)+)tiramisu::codegen',
self.original_str)[0]
self.original_str)[0]
self.name = re.findall(r'tiramisu::init\(\"(\w+)\"\);',
self.original_str)[0]
self.comps = re.findall(r'computation (\w+)\(',
self.original_str)
self.original_str)[0]
self.comps = re.findall(r'computation (\w+)\(', self.original_str)
self.code_gen_line = re.findall(r'tiramisu::codegen\({.+;',
self.original_str)[0]
buffers_vect = re.findall(r'{(.+)}', self.code_gen_line)[0]
Expand Down
Loading

0 comments on commit c1a0ecc

Please sign in to comment.