diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b4fece12..71a8605c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -2,13 +2,21 @@ name: MLGO CI -on: [push, repository_dispatch, pull_request] +permissions: + contents: read + +on: + push: + branches: + - 'main' + repository_dispatch: + pull_request: jobs: LicenseCheck: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - run: ./check-license.sh Envvars: runs-on: ubuntu-latest @@ -46,17 +54,17 @@ jobs: - task: Test cmd: pytest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Python With Cached pip Packages if: needs.Envvars.outputs.do_cache == '1' - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: 'pipenv' cache-dependency-path: Pipfile.lock - name: Install Python, no cache if: needs.Envvars.outputs.do_cache == '0' - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install Pipenv diff --git a/compiler_opt/es/es_worker.py b/compiler_opt/es/es_worker.py deleted file mode 100644 index 22b3ff42..00000000 --- a/compiler_opt/es/es_worker.py +++ /dev/null @@ -1,40 +0,0 @@ -# coding=utf-8 -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Worker for ES Training.""" - -import gin -from typing import List -from compiler_opt.distributed import worker -from compiler_opt.rl import corpus - - -@gin.configurable -class ESWorker(worker.Worker): - """Temporary placeholder worker. - Each time a worker is called, the function value - it will return increases.""" - - def __init__(self, arg, *, kwarg): - self._arg = arg - self._kwarg = kwarg - self.function_value = 0.0 - - def temp_compile(self, policy: bytes, - samples: List[corpus.ModuleSpec]) -> float: - if policy and samples: - self.function_value += 1.0 - return self.function_value - else: - return 0.0 diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py index 34bdccd3..dc44dd0f 100644 --- a/compiler_opt/rl/corpus.py +++ b/compiler_opt/rl/corpus.py @@ -137,7 +137,7 @@ def reset(self): def __call__(self, k: int, n: int = 20) -> List[ModuleSpec]: """ Args: - k: number of modules to sample + k: The number of corpus elements to sample. n: number of buckets to use """ raise NotImplementedError() @@ -159,7 +159,7 @@ def __call__(self, k: int, n: int = 20) -> List[ModuleSpec]: """ Args: module_specs: list of module_specs to sample from - k: number of modules to sample + k: The number of corpus elements (modules) to sample n: number of buckets to use """ # Credits to yundi@ for the highly optimized algo. @@ -210,7 +210,7 @@ def reset(self): def __call__(self, k: Optional[int], n: int = 10) -> List[ModuleSpec]: """ Args: - k: number of modules to sample + k: number of corpus elements (modules) to sample n: ignored Raises: CorpusExhaustedError if there are fewer than k elements left to sample in @@ -240,8 +240,8 @@ def __call__(self, k: int, n: int = 10) -> List[ModuleSpec]: """Returns the entire corpus as a list of module specs. Args: - k: The number of modules to sample. This must be equal to the number - of modules in the corpus. + k: The number of corpus elements to sample. Given a corpus element is + the whole corpus in this case, k should always be one. n: Ignored for this sampler. Returns: @@ -251,10 +251,10 @@ def __call__(self, k: int, n: int = 10) -> List[ModuleSpec]: ValueError: If the requested number of modules is not equal to the number of modules in the corpus. """ - if len(self._module_specs) != k: + if k != 1: raise ValueError( - f'The number of modules requested {k} is not equal to ' - f'the number of modules in the corpus, {len(self._module_specs)}') + f'The number of corpus elements requested {k} is not equal to ' + f'1, the number of corpus elements (whole corpora) present.') return list(self._module_specs) @@ -418,9 +418,9 @@ def reset(self): self._sampler.reset() def sample(self, k: int, sort: bool = False) -> List[ModuleSpec]: - """Samples `k` module_specs, optionally sorting by size descending. + """Samples `k` corpus elements, optionally sorting by size descending. - Use load_corpus_element to get LoadedModuleSpecs - this allows the user + Use load_module_spec to get LoadedModuleSpecs - this allows the user to decide how the loading should happen (e.g. may want to use a threadpool) """ # Note: sampler is intentionally defaulted to a mutable object, as the diff --git a/compiler_opt/rl/corpus_test.py b/compiler_opt/rl/corpus_test.py index fc7f41f0..c62e2c03 100644 --- a/compiler_opt/rl/corpus_test.py +++ b/compiler_opt/rl/corpus_test.py @@ -277,7 +277,7 @@ def test_whole_corpus_sampler(self): corpus.ModuleSpec(name='large', size=100) ], sampler_type=corpus.WholeCorpusSampler) - sample = cps.sample(4, sort=True) + sample = cps.sample(1, sort=True) self.assertLen(sample, 4) self.assertEqual(sample[0].name, 'large') self.assertEqual(sample[1].name, 'middle') @@ -293,7 +293,7 @@ def test_whole_corpus_sampler_invalid_count(self): ], sampler_type=corpus.WholeCorpusSampler) with self.assertRaises(ValueError): - cps.sample(1) + cps.sample(2) def test_filter(self): cps = corpus.create_corpus_for_testing( diff --git a/compiler_opt/rl/env.py b/compiler_opt/rl/env.py index 1bd96afb..904fd388 100644 --- a/compiler_opt/rl/env.py +++ b/compiler_opt/rl/env.py @@ -359,6 +359,7 @@ def reset(self, module: corpus.LoadedModuleSpec): self._clang_generator.send(None) # pytype: disable=attribute-error self._iclang, self._clang = self._clang_generator.send(module) + # pytype: enable=attribute-error return self._get_observation() def step(self, action: np.ndarray):