Skip to content

Commit

Permalink
Inference returns num_tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
icppWorld committed Apr 13, 2024
1 parent af4ffa3 commit b09c6e9
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 9 deletions.
4 changes: 2 additions & 2 deletions icpp_llama2/demo_pytest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ dfx start --clean --background
#######################################################################
echo "--------------------------------------------------"
echo "Building the wasm with wasi-sdk"
# icpp build-wasm --to-compile all
icpp build-wasm --to-compile mine
icpp build-wasm --to-compile all
# icpp build-wasm --to-compile mine

#######################################################################
echo " "
Expand Down
215 changes: 215 additions & 0 deletions icpp_llama2/llama2.did
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Candid interface of the canister endpoints
// https://internetcomputer.org/docs/current/references/candid-ref/

type Prompt = record {
prompt : text;
steps : nat64;
temperature : float32;
topp : float32;
rng_seed : nat64;
};

// Motoko does not support float32, so we use float64, and then map PromptMo onto Prompt
type PromptMo = record {
prompt : text;
steps : nat64;
temperature : float64;
topp : float64;
rng_seed : nat64;
};

type Config = record {
dim : int;
hidden_dim : int;
n_layers : int;
n_heads : int;
n_kv_heads : int;
vocab_size : int;
seq_len : int;
};

// ----------------------------------------------------------
// New approach to endpoint return values:
// -> wrap a record in a Result

// --
type ApiError = variant {
InvalidId;
Other : text;
StatusCode : nat16;
ZeroAddress;
};

// --
// Returned by several endpoints.
// HTTP status code wrapped in a Record wrapped in a Result
type StatusCodeRecordResult = variant {
Err : ApiError;
Ok : StatusCodeRecord;
};
type StatusCodeRecord = record { status_code : nat16 };

// --
// Returned by 'set_canister_mode'
type CanisterModeRecordResult = variant {
Err : ApiError;
Ok : CanisterModeRecord;
};
type CanisterModeRecord = record { canister_mode : text };

// --
// Returned by 'inference', 'nft_story_start', 'nft_story_continue'
// (and their '_mo' variants, which take a PromptMo instead of a Prompt).
// Section of a story, generated by a single inference call
type InferenceRecordResult = variant {
Err : ApiError;
Ok : InferenceRecord;
};
// The canister appends a 'num_tokens' nat64 field next to the generated text
// when it sends the Ok record to the wire, so the field is declared here to
// keep this interface in sync with what is actually returned.
type InferenceRecord = record {
inference : text;
num_tokens : nat64;
};

// --
// A story, from the beginning, built from multiple inference calls
type StoryRecordResult = variant {
Err : ApiError;
Ok : StoryRecord;
};
type StoryRecord = record { story : text };

// --
// Metadata for an NFT collection
type NFTCollectionRecordResult = variant {
Err : ApiError;
Ok : NFTCollectionRecord;
};
type NFTCollectionRecord = record {
nft_supply_cap : nat64;
nft_total_supply : nat64;
nft_symbol : text;
nft_name : text;
nft_description : text;
};

// --
// Returned by 'get_users'
type UsersRecordResult = variant {
Err : ApiError;
Ok : UsersRecord;
};
type UsersRecord = record {
user_count : nat64;
user_ids : vec text;
};

// --
// Returned by 'get_user_metadata'

type UserMetadataRecordResult = variant {
Err : ApiError;
Ok : UserMetadataRecord;
};
type UserMetadataRecord = record {
chats_start_time : vec nat64;
chats_total_steps : vec nat64;
};

// ----------------------------------------------------------

type NFTWhitelistRecord = record {
id : principal;
description : text;
};

type NFT = record {
token_id : text;
};

// --------------------------------------------------------------------------------
// HTTP Gateway Protocol
// https://internetcomputer.org/docs/current/references/http-gateway-protocol-spec#canister-http-interface
// https://internetcomputer.org/docs/current/references/http-gateway-protocol-spec
// https://internetcomputer.org/docs/current/references/ic-interface-spec/#ic-candid

type HeaderField = record { text; text };

type HttpRequest = record {
method : text;
url : text;
headers : vec HeaderField;
body : blob;
certificate_version : opt nat16;
};

// type HttpUpdateRequest = record {
// method: text;
// url: text;
// headers: vec HeaderField;
// body: blob;
// };

type HttpResponse = record {
status_code : nat16;
headers : vec HeaderField;
body : blob;
upgrade : opt bool;
// streaming_strategy: opt StreamingStrategy;
};

/* StreamingStrategy is NOT YET SUPPORTED
// Each canister that uses the streaming feature gets to choose their concrete
// type; the HTTP Gateway will treat it as an opaque value that is only fed to
// the callback method

type StreamingToken = === application-specific type ===

type StreamingCallbackHttpResponse = record {
body: blob;
token: opt StreamingToken;
};

type StreamingStrategy = variant {
Callback: record {
callback: func (StreamingToken) -> (opt StreamingCallbackHttpResponse) query;
token: StreamingToken;
};
};
*/

// Public interface of the canister, grouped by purpose.
// Methods annotated with 'query' are declared as query calls; every other
// method is an update call.
service : {
// canister endpoints
canister_init : () -> ();
set_canister_mode : (text) -> (StatusCodeRecordResult);
health : () -> (StatusCodeRecordResult) query;
ready : () -> (StatusCodeRecordResult) query;

// LLM initialization endpoints
reset_model : () -> (StatusCodeRecordResult);
reset_tokenizer : () -> (StatusCodeRecordResult);
upload_model_bytes_chunk : (vec nat8) -> (StatusCodeRecordResult);
upload_tokenizer_bytes_chunk : (vec nat8) -> (StatusCodeRecordResult);
initialize : () -> (StatusCodeRecordResult);
get_model_config : () -> (Config) query;

// Chat endpoints for canister_mode=chat-principal
// The '_mo' variant takes a PromptMo (float64 fields) because Motoko does
// not support float32.
new_chat : () -> (StatusCodeRecordResult);
inference : (Prompt) -> (InferenceRecordResult);
inference_mo : (PromptMo) -> (InferenceRecordResult);

// admin endpoints
whoami : () -> (text) query;
get_users : () -> (UsersRecordResult) query;
get_user_metadata : (text) -> (UserMetadataRecordResult) query;

// http endpoints
// Uses the HttpRequest / HttpResponse records of the HTTP Gateway Protocol.
http_request : (request : HttpRequest) -> (HttpResponse) query;

// nft endpoints (for canister_mode=nft-ordinal)
// As with the chat endpoints, each '_mo' variant takes a PromptMo instead of
// a Prompt.
nft_whitelist : (NFTWhitelistRecord) -> (StatusCodeRecordResult);
nft_ami_whitelisted : () -> (StatusCodeRecordResult);
nft_init : (NFTCollectionRecord) -> (StatusCodeRecordResult);
nft_metadata : () -> (NFTCollectionRecordResult) query;
nft_mint : (NFT) -> (StatusCodeRecordResult);
nft_story_start : (NFT, Prompt) -> (InferenceRecordResult);
nft_story_start_mo : (NFT, PromptMo) -> (InferenceRecordResult);
nft_story_continue : (NFT, Prompt) -> (InferenceRecordResult);
nft_story_continue_mo : (NFT, PromptMo) -> (InferenceRecordResult);
nft_get_story : (NFT) -> (StoryRecordResult) query;
};
9 changes: 7 additions & 2 deletions icpp_llama2/native/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,9 @@ int main() {
std::array<uint64_t, 2> rng_seed = {0, 0};

std::array<std::string, 2> generated_tokens = {"", ""};
std::array<uint64_t, 2> num_tokens = {0, 0};
std::array<std::string, 2> story = {"", ""};

for (int i = 0; i < 10; i++) {
for (int j = 0; j < 2; j++) {

Expand All @@ -464,6 +466,7 @@ int main() {
CandidTypeRecord inference_record;
inference_record.append("inference",
CandidTypeText{&generated_tokens[j]});
inference_record.append("num_tokens", CandidTypeNat64{&num_tokens[j]});
std::string err_text;
CandidTypeVariant v_out;
v_out.append("Ok", inference_record);
Expand Down Expand Up @@ -521,6 +524,7 @@ int main() {
uint64_t rng_seed = 0;

std::string generated_tokens = "";
uint64_t num_tokens = 0;
std::string story = "";
for (int i = 0; i < 100; i++) {
CandidTypeRecord r_in;
Expand All @@ -537,6 +541,7 @@ int main() {

CandidTypeRecord inference_record;
inference_record.append("inference", CandidTypeText{&generated_tokens});
inference_record.append("num_tokens", CandidTypeNat64{&num_tokens});
std::string err_text;
CandidTypeVariant v_out;
v_out.append("Ok", inference_record);
Expand Down Expand Up @@ -571,11 +576,11 @@ int main() {

// With temperature=0.0: greedy argmax sampling -> the story will be the same every time
// '(record {prompt = "" : text; steps = 100 : nat64; temperature = 0.0 : float32; topp = 1.0 : float32; rng_seed = 0 : nat64;})'
// -> '(variant { Ok = record { inference = "...story..." : text;} })'
// -> '(variant { Ok = record { inference = "...story..." : text; num_tokens = 101 : nat64;} })'
expected_response = "-to-do-B-";
if (model_to_use == 1) {
expected_response =
"4449444c026c01d9b3b9980f716b01bc8a0100010100fd014f6e63652075706f6e20612074696d652c207468657265207761732061206c6974746c65206769726c206e616d6564204c696c792e20536865206c6f76656420746f20706c6179206f75747369646520696e20746865207061726b2e204f6e65206461792c20736865207361772061206269672c207265642062616c6c2e205368652077616e74656420746f20706c617920776974682069742c206275742069742077617320746f6f20686967682e0a4c696c792773206d6f6d20736169642c20224c696c792c206c6574277320676f20746f20746865207061726b2e22204c696c79207761732073616420616e64206469646e2774206b6e6f772077";
"4449444c026c02f3feb4990678d9b3b9980f716b01bc8a01000101006500000000000000fd014f6e63652075706f6e20612074696d652c207468657265207761732061206c6974746c65206769726c206e616d6564204c696c792e20536865206c6f76656420746f20706c6179206f75747369646520696e20746865207061726b2e204f6e65206461792c20736865207361772061206269672c207265642062616c6c2e205368652077616e74656420746f20706c617920776974682069742c206275742069742077617320746f6f20686967682e0a4c696c792773206d6f6d20736169642c20224c696c792c206c6574277320676f20746f20746865207061726b2e22204c696c79207761732073616420616e64206469646e2774206b6e6f772077";
} else if (model_to_use == 2) {
} else if (model_to_use == 3) {
} else if (model_to_use == 4) {
Expand Down
2 changes: 1 addition & 1 deletion icpp_llama2/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-r scripts/requirements.txt
icpp-pro==3.13.0
icpp-pro==3.15.2
ic-py==1.0.1
requests
3 changes: 0 additions & 3 deletions icpp_llama2/src/canister.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ bool is_canister_owner(IC_API &ic_api, bool err_to_wire) {
CandidTypePrincipal caller = ic_api.get_caller();
if (caller.get_text() == *p_canister_owner_principal) return true;
else {
IC_API::debug_print(std::string(__func__) +
": ERROR - caller is not the owner.");

if (err_to_wire) {
std::string error_msg = "Access Denied";
ic_api.to_wire(CandidTypeVariant{
Expand Down
1 change: 1 addition & 0 deletions icpp_llama2/src/inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ void inference_(bool from_motoko) {
// Send the generated response to the wire
CandidTypeRecord inference_record;
inference_record.append("inference", CandidTypeText{output});
inference_record.append("num_tokens", CandidTypeNat64{chat->total_steps});
ic_api.to_wire(CandidTypeVariant{"Ok", CandidTypeRecord{inference_record}});
}

Expand Down
5 changes: 4 additions & 1 deletion icpp_llama2/src/llama2.did
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ type InferenceRecordResult = variant {
Err : ApiError;
Ok : InferenceRecord;
};
type InferenceRecord = record { inference : text };
type InferenceRecord = record {
inference : text;
num_tokens : nat64;
};

// --
// A story, from the beginning, built from multiple inference calls
Expand Down

0 comments on commit b09c6e9

Please sign in to comment.