Skip to content

Commit

Permalink
Make owner the calling address, and creator is the campaign manager
Browse files Browse the repository at this point in the history
  • Loading branch information
Nenad committed Jun 12, 2024
1 parent b24965d commit 347e930
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 25 deletions.
9 changes: 5 additions & 4 deletions listings/applications/advanced_factory/src/tests.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ fn test_deploy_campaign() {

let mut spy = spy_events(SpyOn::One(factory.contract_address));

let campaign_owner: ContractAddress = contract_address_const::<'campaign_owner'>();
start_cheat_caller_address(factory.contract_address, campaign_owner);
let campaign_creator: ContractAddress = contract_address_const::<'campaign_creator'>();
start_cheat_caller_address(factory.contract_address, campaign_creator);

let title: ByteArray = "New campaign";
let description: ByteArray = "Some description";
Expand All @@ -77,9 +77,10 @@ fn test_deploy_campaign() {
assert_eq!(details.status, Status::PENDING);
assert_eq!(details.token, token);
assert_eq!(details.total_contributions, 0);
assert_eq!(details.creator, campaign_creator);

let campaign_ownable = IOwnableDispatcher { contract_address: campaign_address };
assert_eq!(campaign_ownable.owner(), campaign_owner);
assert_eq!(campaign_ownable.owner(), factory.contract_address);

spy
.assert_emitted(
Expand All @@ -88,7 +89,7 @@ fn test_deploy_campaign() {
factory.contract_address,
CampaignFactory::Event::CampaignCreated(
CampaignFactory::CampaignCreated {
caller: campaign_owner, contract_address: campaign_address
caller: campaign_creator, contract_address: campaign_address
}
)
)
Expand Down
32 changes: 20 additions & 12 deletions listings/applications/crowdfunding/src/campaign.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub enum Status {

#[derive(Drop, Serde)]
pub struct Details {
pub creator: ContractAddress,
pub target: u256,
pub title: ByteArray,
pub end_time: u64,
Expand Down Expand Up @@ -42,7 +43,7 @@ pub mod Campaign {
use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait};
use starknet::{
ClassHash, ContractAddress, SyscallResultTrait, get_block_timestamp, contract_address_const,
get_caller_address, get_contract_address
get_caller_address, get_contract_address, class_hash::class_hash_const
};
use components::ownable::ownable_component;
use super::contributions::contributable_component;
Expand Down Expand Up @@ -120,7 +121,7 @@ pub mod Campaign {
}

pub mod Errors {
pub const NOT_FACTORY: felt252 = 'Caller not factory';
pub const NOT_CREATOR: felt252 = 'Not creator';
pub const ENDED: felt252 = 'Campaign already ended';
pub const NOT_PENDING: felt252 = 'Campaign not pending';
pub const STILL_ACTIVE: felt252 = 'Campaign not ended';
Expand All @@ -131,7 +132,7 @@ pub mod Campaign {
pub const TRANSFER_FAILED: felt252 = 'Transfer failed';
pub const TITLE_EMPTY: felt252 = 'Title empty';
pub const CLASS_HASH_ZERO: felt252 = 'Class hash cannot be zero';
pub const FACTORY_ZERO: felt252 = 'Factory address cannot be zero';
pub const ZERO_ADDRESS_CALLER: felt252 = 'Caller cannot be zero';
pub const CREATOR_ZERO: felt252 = 'Creator address cannot be zero';
pub const TARGET_NOT_REACHED: felt252 = 'Target not reached';
pub const TARGET_ALREADY_REACHED: felt252 = 'Target already reached';
Expand All @@ -141,14 +142,14 @@ pub mod Campaign {
#[constructor]
fn constructor(
ref self: ContractState,
owner: ContractAddress,
creator: ContractAddress,
title: ByteArray,
description: ByteArray,
target: u256,
duration: u64,
token_address: ContractAddress
) {
assert(owner.is_non_zero(), Errors::CREATOR_ZERO);
assert(creator.is_non_zero(), Errors::CREATOR_ZERO);
assert(title.len() > 0, Errors::TITLE_EMPTY);
assert(target > 0, Errors::ZERO_TARGET);
assert(duration > 0, Errors::ZERO_DURATION);
Expand All @@ -159,15 +160,15 @@ pub mod Campaign {
self.target.write(target);
self.description.write(description);
self.end_time.write(get_block_timestamp() + duration);
self.creator.write(get_caller_address());
self.ownable._init(owner);
self.creator.write(creator);
self.ownable._init(get_caller_address());
self.status.write(Status::PENDING)
}

#[abi(embed_v0)]
impl Campaign of super::ICampaign<ContractState> {
fn claim(ref self: ContractState) {
self.ownable._assert_only_owner();
self._assert_only_creator();
assert(self._is_active(), Errors::ENDED);
assert(self._is_target_reached(), Errors::TARGET_NOT_REACHED);
// no need to check end_time, as the owner can prematurely end the campaign
Expand All @@ -191,7 +192,7 @@ pub mod Campaign {
}

fn close(ref self: ContractState, reason: ByteArray) {
self.ownable._assert_only_owner();
self._assert_only_creator();
assert(self._is_active(), Errors::ENDED);

self.status.write(Status::CLOSED);
Expand Down Expand Up @@ -222,6 +223,7 @@ pub mod Campaign {

fn get_details(self: @ContractState) -> Details {
Details {
creator: self.creator.read(),
title: self.title.read(),
description: self.description.read(),
target: self.target.read(),
Expand All @@ -233,7 +235,7 @@ pub mod Campaign {
}

fn start(ref self: ContractState) {
self.ownable._assert_only_owner();
self._assert_only_creator();
assert(self.status.read() == Status::PENDING, Errors::NOT_PENDING);

self.status.write(Status::ACTIVE);
Expand All @@ -242,8 +244,8 @@ pub mod Campaign {
}

fn upgrade(ref self: ContractState, impl_hash: ClassHash) -> Result<(), Array<felt252>> {
if get_caller_address() != self.creator.read() {
return Result::Err(array![Errors::NOT_FACTORY]);
if get_caller_address() != self.ownable.owner() {
return Result::Err(array![components::ownable::Errors::UNAUTHORIZED]);
}
if impl_hash.is_zero() {
return Result::Err(array![Errors::CLASS_HASH_ZERO]);
Expand Down Expand Up @@ -282,6 +284,12 @@ pub mod Campaign {

#[generate_trait]
impl CampaignInternalImpl of CampaignInternalTrait {
fn _assert_only_creator(self: @ContractState) {
let caller = get_caller_address();
assert(caller.is_non_zero(), Errors::ZERO_ADDRESS_CALLER);
assert(caller == self.creator.read(), Errors::NOT_CREATOR);
}

fn _is_expired(self: @ContractState) -> bool {
get_block_timestamp() < self.end_time.read()
}
Expand Down
19 changes: 10 additions & 9 deletions listings/applications/crowdfunding/src/tests.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ use components::ownable::{IOwnableDispatcher, IOwnableDispatcherTrait};
fn deploy_with(
title: ByteArray, description: ByteArray, target: u256, duration: u64, token: ContractAddress
) -> ICampaignDispatcher {
let owner = contract_address_const::<'owner'>();
let creator = contract_address_const::<'creator'>();
let mut calldata: Array::<felt252> = array![];
((owner, title, description, target), duration, token).serialize(ref calldata);
((creator, title, description, target), duration, token).serialize(ref calldata);

let contract = declare("Campaign").unwrap();
let contract_address = contract.precalculate_address(@calldata);
let factory = contract_address_const::<'factory'>();
start_cheat_caller_address(contract_address, factory);
let owner = contract_address_const::<'owner'>();
start_cheat_caller_address(contract_address, owner);

contract.deploy(@calldata).unwrap();

Expand All @@ -50,6 +50,7 @@ fn test_deploy() {
assert_eq!(details.status, Status::PENDING);
assert_eq!(details.token, contract_address_const::<'token'>());
assert_eq!(details.total_contributions, 0);
assert_eq!(details.creator, contract_address_const::<'creator'>());

let owner: ContractAddress = contract_address_const::<'owner'>();
let campaign_ownable = IOwnableDispatcher { contract_address: campaign.contract_address };
Expand All @@ -64,8 +65,8 @@ fn test_upgrade_class_hash() {

let new_class_hash = declare("MockContract").unwrap().class_hash;

let factory = contract_address_const::<'factory'>();
start_cheat_caller_address(campaign.contract_address, factory);
let owner = contract_address_const::<'owner'>();
start_cheat_caller_address(campaign.contract_address, owner);

if let Result::Err(errs) = campaign.upgrade(new_class_hash) {
panic(errs)
Expand All @@ -85,14 +86,14 @@ fn test_upgrade_class_hash() {
}

#[test]
#[should_panic(expected: 'Caller not factory')]
#[should_panic(expected: 'Not owner')]
fn test_upgrade_class_hash_fail() {
let campaign = deploy();

let new_class_hash = declare("MockContract").unwrap().class_hash;

let owner = contract_address_const::<'owner'>();
start_cheat_caller_address(campaign.contract_address, owner);
let random_address = contract_address_const::<'random_address'>();
start_cheat_caller_address(campaign.contract_address, random_address);

if let Result::Err(errs) = campaign.upgrade(new_class_hash) {
panic(errs)
Expand Down

0 comments on commit 347e930

Please sign in to comment.