Skip to content

Commit

Permalink
Update endpoints with alternative for calling from Motoko
Browse files Browse the repository at this point in the history
Motoko does not support float32, so we provide an alternative interface
  • Loading branch information
icppWorld committed Mar 30, 2024
1 parent ad70f82 commit 4225588
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 24 deletions.
30 changes: 15 additions & 15 deletions icpp_llama2/native/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,16 +584,16 @@ int main() {
expected_response, silent_on_trap, my_principal);

// -----------------------------------------------------------------------------------------
// A new chat
// A new chat, simulating a call from Motoko, which uses float64
// '()' -> '(variant { Ok = 200 : nat16 })'
mockIC.run_test("new_chat", new_chat, "4449444c0000",
"4449444c026c019aa1b2f90c7a6b01bc8a0100010100c800",
silent_on_trap, my_principal);
// With temperature=0.0 & topp=0.9, still greedy argmax sampling -> the story will be the same every time
// '(record {prompt = "" : text; steps = 100 : nat64; temperature = 0.0 : float32; topp = 0.9 : float32; rng_seed = 0 : nat64;})'
// '(record {prompt = "" : text; steps = 100 : nat64; temperature = 0.0 : float64; topp = 0.9 : float64; rng_seed = 0 : nat64;})'
mockIC.run_test(
"inference 2", inference,
"4449444c016c05b4e8c2e40373bbb885e80473a7f7b9a00878c5c8cea60878a4a3e1aa0b710100000000006666663f6400000000000000000000000000000000",
"inference_mo 2", inference_mo,
"4449444c016c05b4e8c2e40372bbb885e80472a7f7b9a00878c5c8cea60878a4a3e1aa0b7101000000000000000000cdccccccccccec3f6400000000000000000000000000000000",
expected_response, silent_on_trap, my_principal);

// -----------------------------------------------------------------------------------------
Expand Down Expand Up @@ -918,10 +918,10 @@ int main() {
expected_response, silent_on_trap, my_principal);

// ------------------------------------------------------------------------
// Create the story for nft_id=1 with token_id="token-B"
// Create the story for nft_id=1 with token_id="token-B", called as if from Motoko with PromptMo (float64 types)

// With temperature=0.0: greedy argmax sampling -> the story will be the same every time
// '(record {token_id = "token-B" : text}, record{ prompt = "Charles had a boat." : text; steps = 100 : nat64; temperature = 0.0 : float32; topp = 1.0 : float32; rng_seed = 0 : nat64;})'
// '(record {token_id = "token-B" : text}, record{ prompt = "Charles had a boat." : text; steps = 100 : nat64; temperature = 0.0 : float64; topp = 1.0 : float64; rng_seed = 0 : nat64;})'
expected_response = "-to-do-B-";
if (model_to_use == 1) {
// -> '(variant { Ok = record { inference = "...some story..." : text;} })'
Expand All @@ -935,17 +935,17 @@ int main() {
// A regular user cannot create new stories for NFTs
// -> '(variant { Err = variant { Other = "Access Denied - You are not authorized to call this function." } })'
mockIC.run_test(
"nft_story_start 1 Err test", nft_story_start,
"4449444c026c01a1a1c1da02716c05b4e8c2e40373bbb885e80473a7f7b9a00878c5c8cea60878a4a3e1aa0b7102000107746f6b656e2d42000000000000803f6400000000000000000000000000000013436861726c657320686164206120626f61742e",
"nft_story_start_mo 1 Err test", nft_story_start_mo,
"4449444c026c01a1a1c1da02716c05b4e8c2e40372bbb885e80472a7f7b9a00878c5c8cea60878a4a3e1aa0b7102000107746f6b656e2d420000000000000000000000000000f03f6400000000000000000000000000000013436861726c657320686164206120626f61742e",
"4449444c026b01b0ad8fcd0c716b01c5fed20100010100003d4163636573732044656e696564202d20596f7520617265206e6f7420617574686f72697a656420746f2063616c6c20746869732066756e6374696f6e2e",
silent_on_trap, your_principal);
mockIC.run_test(
"nft_story_start 1", nft_story_start,
"4449444c026c01a1a1c1da02716c05b4e8c2e40373bbb885e80473a7f7b9a00878c5c8cea60878a4a3e1aa0b7102000107746f6b656e2d42000000000000803f6400000000000000000000000000000013436861726c657320686164206120626f61742e",
"nft_story_start_mo 1", nft_story_start_mo,
"4449444c026c01a1a1c1da02716c05b4e8c2e40372bbb885e80472a7f7b9a00878c5c8cea60878a4a3e1aa0b7102000107746f6b656e2d420000000000000000000000000000f03f6400000000000000000000000000000013436861726c657320686164206120626f61742e",
expected_response, silent_on_trap, my_principal);

// With temperature=0.0: greedy argmax sampling -> the story will be the same every time
// '(record {token_id = "token-B" : text}, record{ prompt = "" : text; steps = 100 : nat64; temperature = 0.0 : float32; topp = 1.0 : float32; rng_seed = 0 : nat64;})'
// '(record {token_id = "token-B" : text}, record{ prompt = "" : text; steps = 100 : nat64; temperature = 0.0 : float64; topp = 1.0 : float64; rng_seed = 0 : nat64;})'
expected_response = "-to-do-B-";
if (model_to_use == 1) {
// -> '(variant { Ok = record { inference = "...some story..." : text;} })'
Expand All @@ -959,13 +959,13 @@ int main() {
// A regular user cannot continue stories for NFTs
// -> '(variant { Err = variant { Other = "Access Denied - You are not authorized to call this function." } })'
mockIC.run_test(
"nft_story_continue Err test", nft_story_continue,
"4449444c026c01a1a1c1da02716c05b4e8c2e40373bbb885e80473a7f7b9a00878c5c8cea60878a4a3e1aa0b7102000107746f6b656e2d42000000000000803f6400000000000000000000000000000000",
"nft_story_continue_mo Err test", nft_story_continue_mo,
"4449444c026c01a1a1c1da02716c05b4e8c2e40372bbb885e80472a7f7b9a00878c5c8cea60878a4a3e1aa0b7102000107746f6b656e2d420000000000000000000000000000f03f6400000000000000000000000000000000",
"4449444c026b01b0ad8fcd0c716b01c5fed20100010100003d4163636573732044656e696564202d20596f7520617265206e6f7420617574686f72697a656420746f2063616c6c20746869732066756e6374696f6e2e",
silent_on_trap, your_principal);
mockIC.run_test(
"nft_story_continue 1", nft_story_continue,
"4449444c026c01a1a1c1da02716c05b4e8c2e40373bbb885e80473a7f7b9a00878c5c8cea60878a4a3e1aa0b7102000107746f6b656e2d42000000000000803f6400000000000000000000000000000000",
"nft_story_continue_mo 1", nft_story_continue_mo,
"4449444c026c01a1a1c1da02716c05b4e8c2e40372bbb885e80472a7f7b9a00878c5c8cea60878a4a3e1aa0b7102000107746f6b656e2d420000000000000000000000000000f03f6400000000000000000000000000000000",
expected_response, silent_on_trap, my_principal);

// ------------------------------------------------------------------------
Expand Down
23 changes: 20 additions & 3 deletions icpp_llama2/src/inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,11 @@ std::string generate(IC_API ic_api, Chat *chat, Transformer *transformer,
}

// Inference endpoint for ICGPT, with story ownership based on principal of caller
void inference() {
// Endpoint for callers that send temperature/topp as float32 (Prompt).
// Forwards to inference_ with from_motoko=false.
void inference() { inference_(false); }
// Endpoint for callers that send temperature/topp as float64 (PromptMo).
// Forwards to inference_ with from_motoko=true, which decodes the two
// fields as float64 and then narrows them to float.
void inference_mo() {
inference_(true);
} // Use this when calling from Motoko, with float64
void inference_(bool from_motoko) {
IC_API ic_api(CanisterUpdate{std::string(__func__)}, false);
if (!is_canister_mode_chat_principal()) {
std::string error_msg =
Expand All @@ -175,14 +179,27 @@ void inference() {
if (!is_ready_and_authorized(ic_api)) return;

// Get the Prompt from the wire
PromptMo wire_prompt_motoko; // Motoko does not support float32, uses float64
Prompt wire_prompt;
CandidTypeRecord r_in;
r_in.append("prompt", CandidTypeText{&wire_prompt.prompt});
r_in.append("steps", CandidTypeNat64{&wire_prompt.steps});
r_in.append("temperature", CandidTypeFloat32{&wire_prompt.temperature});
r_in.append("topp", CandidTypeFloat32{&wire_prompt.topp});
if (from_motoko) {
r_in.append("temperature",
CandidTypeFloat64{&wire_prompt_motoko.temperature});
r_in.append("topp", CandidTypeFloat64{&wire_prompt_motoko.topp});
} else {
r_in.append("temperature", CandidTypeFloat32{&wire_prompt.temperature});
r_in.append("topp", CandidTypeFloat32{&wire_prompt.topp});
}
r_in.append("rng_seed", CandidTypeNat64{&wire_prompt.rng_seed});
ic_api.from_wire(r_in);

if (from_motoko) {
wire_prompt.temperature =
static_cast<float>(wire_prompt_motoko.temperature);
wire_prompt.topp = static_cast<float>(wire_prompt_motoko.topp);
}
// print_prompt(wire_prompt);

CandidTypePrincipal caller = ic_api.get_caller();
Expand Down
2 changes: 2 additions & 0 deletions icpp_llama2/src/inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#include "prompt.h"

// Update endpoint; its record argument carries temperature/topp as float32.
void inference() WASM_SYMBOL_EXPORTED("canister_update inference");
// Update endpoint for Motoko callers; temperature/topp arrive as float64.
void inference_mo() WASM_SYMBOL_EXPORTED("canister_update inference_mo");

// Shared implementation behind both endpoints above.
// from_motoko=true decodes temperature/topp as float64 instead of float32.
void inference_(bool from_motoko);
std::string do_inference(IC_API &ic_api, Prompt wire_prompt, Chat *chat,
std::string *output_history,
MetadataUser *metadata_user, bool *error);
12 changes: 12 additions & 0 deletions icpp_llama2/src/llama2.did
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ type Prompt = record {
rng_seed : nat64;
};

// Motoko does not support float32, so this record carries temperature and
// topp as float64. The canister maps a received PromptMo onto a Prompt.
// Accepted by the *_mo service methods (inference_mo, nft_story_start_mo,
// nft_story_continue_mo).
type PromptMo = record {
prompt : text;
steps : nat64;
temperature : float64;
topp : float64;
rng_seed : nat64;
};

type Config = record {
dim : int;
hidden_dim : int;
Expand Down Expand Up @@ -182,6 +191,7 @@ service : {
// Chat endpoints for canister_mode=chat-principal
new_chat : () -> (StatusCodeRecordResult);
inference : (Prompt) -> (InferenceRecordResult);
inference_mo : (PromptMo) -> (InferenceRecordResult);

// admin endpoints
whoami : () -> (text) query;
Expand All @@ -198,6 +208,8 @@ service : {
nft_metadata : () -> (NFTCollectionRecordResult) query;
nft_mint : (NFT) -> (StatusCodeRecordResult);
nft_story_start : (NFT, Prompt) -> (InferenceRecordResult);
nft_story_start_mo : (NFT, PromptMo) -> (InferenceRecordResult);
nft_story_continue : (NFT, Prompt) -> (InferenceRecordResult);
nft_story_continue_mo : (NFT, PromptMo) -> (InferenceRecordResult);
nft_get_story : (NFT) -> (StoryRecordResult) query;
};
25 changes: 20 additions & 5 deletions icpp_llama2/src/nft_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,12 @@ void nft_mint() {
}

// Endpoints for the story of an NFT, callable by whitelisted principals only
void nft_story_start() { nft_story_(true); }
void nft_story_continue() { nft_story_(false); }
// Thin entry points that forward to nft_story_(story_start, from_motoko).
// The *_mo variants pass from_motoko=true, so temperature/topp are decoded
// as float64 instead of float32.
void nft_story_start() { nft_story_(true, false); }
void nft_story_start_mo() { nft_story_(true, true); }
void nft_story_continue() { nft_story_(false, false); }
void nft_story_continue_mo() { nft_story_(false, true); }

void nft_story_(bool story_start) {
void nft_story_(bool story_start, bool from_motoko) {
IC_API ic_api(CanisterUpdate{std::string(__func__)}, false);
if (!is_canister_mode_nft_ordinal()) {
std::string error_msg = "Access Denied - Canister is not in NFT mode.";
Expand All @@ -311,19 +313,32 @@ void nft_story_(bool story_start) {
CandidTypeRecord r_in1;
r_in1.append("token_id", CandidTypeText{&token_id});

PromptMo wire_prompt_motoko; // Motoko does not support float32, uses float64
Prompt wire_prompt;
CandidTypeRecord r_in2;
r_in2.append("prompt", CandidTypeText{&wire_prompt.prompt});
r_in2.append("steps", CandidTypeNat64{&wire_prompt.steps});
r_in2.append("temperature", CandidTypeFloat32{&wire_prompt.temperature});
r_in2.append("topp", CandidTypeFloat32{&wire_prompt.topp});
if (from_motoko) {
r_in2.append("temperature",
CandidTypeFloat64{&wire_prompt_motoko.temperature});
r_in2.append("topp", CandidTypeFloat64{&wire_prompt_motoko.topp});
} else {
r_in2.append("temperature", CandidTypeFloat32{&wire_prompt.temperature});
r_in2.append("topp", CandidTypeFloat32{&wire_prompt.topp});
}
r_in2.append("rng_seed", CandidTypeNat64{&wire_prompt.rng_seed});

CandidArgs args;
args.append(r_in1);
args.append(r_in2);
ic_api.from_wire(args);

if (from_motoko) {
wire_prompt.temperature =
static_cast<float>(wire_prompt_motoko.temperature);
wire_prompt.topp = static_cast<float>(wire_prompt_motoko.topp);
}

print_prompt(wire_prompt);

if (story_start or
Expand Down
6 changes: 5 additions & 1 deletion icpp_llama2/src/nft_collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ void nft_init() WASM_SYMBOL_EXPORTED("canister_update nft_init");
void nft_metadata() WASM_SYMBOL_EXPORTED("canister_query nft_metadata");
void nft_mint() WASM_SYMBOL_EXPORTED("canister_update nft_mint");
// Story endpoints whose prompt record carries temperature/topp as float32.
// Each has a *_mo twin for Motoko callers, which send those fields as float64.
void nft_story_start() WASM_SYMBOL_EXPORTED("canister_update nft_story_start");
void nft_story_start_mo()
WASM_SYMBOL_EXPORTED("canister_update nft_story_start_mo");
void nft_story_continue()
WASM_SYMBOL_EXPORTED("canister_update nft_story_continue");
void nft_story_continue_mo()
WASM_SYMBOL_EXPORTED("canister_update nft_story_continue_mo");
void nft_get_story() WASM_SYMBOL_EXPORTED("canister_query nft_get_story");

// ------------------------------------------------
Expand Down Expand Up @@ -60,6 +64,6 @@ void delete_p_nft_whitelist();
bool nft_is_whitelisted(IC_API &ic_api, bool err_to_wire = true);
void new_p_nft_collection();
void delete_p_nft_collection();
void nft_story_(bool story_start);
void nft_story_(bool story_start, bool from_motoko);
bool nft_exists_(const std::string &token_id);
bool nft_story_exists_(const std::string &token_id);
10 changes: 10 additions & 0 deletions icpp_llama2/src/prompt.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,14 @@ class Prompt {
uint64_t rng_seed{0};
};

// Prompt as sent by a Motoko caller. Motoko does not support float32, so
// temperature and topp are held as double (float64 on the wire); after
// decoding they are narrowed to float and copied onto a Prompt.
class PromptMo {
public:
  std::string prompt;       // text to start from; empty by default
  uint64_t steps = 256;     // default step count
  double temperature = 1.0; // 0.0 selects greedy argmax sampling
  double topp = 0.9;        // top-p value, as float64
  uint64_t rng_seed = 0;    // seed value; 0 by default
};

void print_prompt(const Prompt &wire_prompt);

0 comments on commit 4225588

Please sign in to comment.