Skip to content

Commit

Permalink
fix: add test
Browse files Browse the repository at this point in the history
  • Loading branch information
AoyuQC committed Nov 26, 2023
1 parent 3551bb9 commit abb9f13
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/braket/experimental/algorithms/qc_qrl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from unittest.mock import patch

import pytest

from braket.experimental.algorithms.qc_qrl.utility.RetroRLAgent import RetroRLAgent
from braket.experimental.algorithms.qc_qrl.utility.RetroGateModel import RetroGateModel

np.set_printoptions(precision=4, edgeitems=10, linewidth=150, suppress=True)

@pytest.fixture
def common_agent_param():
agent_param = {}
# initial the RetroRLModel object
init_param = {}
method = ['retro-rl', 'retro-qrl']

for mt in method:
if mt == 'retro-rl':
init_param[mt] = {}
init_param[mt]['param'] = ['inputsize', 'middlesize', 'outputsize']
elif mt == 'retro-qrl':
init_param[mt] = {}
init_param[mt]['param'] = ['n_qubits', 'device', 'framework', 'shots', 'layers']

agent_param['init_param'] = init_param
# train_mode can be: "local-instance", "local-job", "hybrid-job"
train_mode = "hybrid-job"

data_path = 'data'
s3_data_path = None

agent_param["data_path"] = data_path
agent_param["s3_data_path"]=s3_data_path
agent_param["train_mode"] = train_mode
agent_param["episodes"] = 2

return agent_param

def test_quantum_circuit_parameters(common_agent_param):
agent_param = common_agent_param

model_param={}
method = 'retro-qrl'
model_param[method] = {}
model_param[method]['n_qubits'] = [8]
model_param[method]['device'] = ['local']
model_param[method]['framework'] = ['pennylane']
model_param[method]['shots'] = [100]
model_param[method]['layers'] = [1]

agent_param['model_param'] = model_param

n_qubits = model_param[method]['n_qubits'][0]
device = model_param[method]['device'][0]
framework = model_param[method]['framework'][0]
shots = model_param[method]['shots'][0]
layers = model_param[method]['layers'][0]

model_name = "{}_{}_{}_{}_{}".format(n_qubits, device, framework, shots, layers)
agent_param["model_name"] = model_name

agent_param["train_mode"]="local-instance"

retro_qrl_agent = RetroRLAgent(build_model=True, method=method, **agent_param)

quantum_param_sum = 0
for param in retro_qrl_agent.NN.parameters():
quantum_param_sum = quantum_param_sum + param.numel()

assert quantum_param_sum == model_param[method]["n_qubits"][0]

def test_classical_circuit_parameters(common_agent_param):
agent_param = common_agent_param

model_param={}
method = 'retro-rl'
model_param[method] = {}
model_param[method]['inputsize'] = [256]
model_param[method]['middlesize'] = [256]
model_param[method]['outputsize'] = [1]

agent_param['model_param'] = model_param
model_name = f"{model_param[method]['inputsize'][0]}_{model_param[method]['middlesize'][0]}_{model_param[method]['outputsize'][0]}"
agent_param["model_name"] = model_name

agent_param["train_mode"]="local-instance"

retro_crl_agent = RetroRLAgent(build_model=True, method=method, **agent_param)

classical_param_sum = 0
for param in retro_crl_agent.NN.parameters():
classical_param_sum = classical_param_sum + param.numel()

assert classical_param_sum == model_param[method]['inputsize'][0]*model_param[method]['middlesize'][0]*model_param[method]['outputsize'][0]

0 comments on commit abb9f13

Please sign in to comment.