Skip to content

Commit

Permalink
[feat][sdk] Support langchain expr with schema
Browse files: browse the repository at this point in the history
  • Loading branch information
wchuande authored and ketor committed Apr 16, 2024
1 parent fd501f1 commit 641339b
Show file tree
Hide file tree
Showing 21 changed files with 491 additions and 371 deletions.
56 changes: 28 additions & 28 deletions src/benchmark/dataset.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -580,7 +580,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = item["id"].GetInt64() + 1;
scalar_value.fields.push_back(field);
Expand All @@ -589,7 +589,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = item["wiki_id"].GetInt64();
scalar_value.fields.push_back(field);
Expand All @@ -598,7 +598,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = item["paragraph_id"].GetInt64();
scalar_value.fields.push_back(field);
Expand All @@ -607,7 +607,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = item["langs"].GetInt64();
scalar_value.fields.push_back(field);
Expand All @@ -616,23 +616,23 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = item["title"].GetString();
scalar_value.fields.push_back(field);
vector_with_id.scalar_data["title"] = scalar_value;
}
{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = item["url"].GetString();
scalar_value.fields.push_back(field);
vector_with_id.scalar_data["url"] = scalar_value;
}
{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = item["text"].GetString();
scalar_value.fields.push_back(field);
Expand Down Expand Up @@ -674,35 +674,35 @@ Dataset::TestEntryPtr Wikipedia2212Dataset::ParseTestData(const rapidjson::Value
const auto& value = kv[1];
if (key == "title") {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = value;
scalar_value.fields.push_back(field);
vector_with_id.scalar_data["title"] = scalar_value;
} else if (key == "text") {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = value;
scalar_value.fields.push_back(field);
vector_with_id.scalar_data["text"] = scalar_value;
} else if (key == "langs") {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = std::stoll(value);
scalar_value.fields.push_back(field);
vector_with_id.scalar_data["langs"] = scalar_value;
} else if (key == "paragraph_id") {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = std::stoll(value);
scalar_value.fields.push_back(field);
vector_with_id.scalar_data["paragraph_id"] = scalar_value;
} else if (key == "wiki_id") {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = std::stoll(value);
scalar_value.fields.push_back(field);
Expand All @@ -714,14 +714,14 @@ Dataset::TestEntryPtr Wikipedia2212Dataset::ParseTestData(const rapidjson::Value
for (const auto& m : item["filter"].GetObject()) {
if (m.value.IsString()) {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = m.value.GetString();
scalar_value.fields.push_back(field);
vector_with_id.scalar_data[m.name.GetString()] = scalar_value;
} else if (m.value.IsInt64()) {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = m.value.GetInt64();
scalar_value.fields.push_back(field);
Expand Down Expand Up @@ -770,7 +770,7 @@ bool BeirBioasqDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorW

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = std::stoll(item["_id"].GetString()) + 1;
scalar_value.fields.push_back(field);
Expand All @@ -779,7 +779,7 @@ bool BeirBioasqDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorW

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = item["title"].GetString();
scalar_value.fields.push_back(field);
Expand All @@ -788,7 +788,7 @@ bool BeirBioasqDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorW

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = item["text"].GetString();
scalar_value.fields.push_back(field);
Expand Down Expand Up @@ -830,14 +830,14 @@ Dataset::TestEntryPtr BeirBioasqDataset::ParseTestData(const rapidjson::Value& o
const auto& value = kv[1];
if (key == "title") {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = value;
scalar_value.fields.push_back(field);
vector_with_id.scalar_data["title"] = scalar_value;
} else if (key == "text") {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = value;
scalar_value.fields.push_back(field);
Expand All @@ -849,14 +849,14 @@ Dataset::TestEntryPtr BeirBioasqDataset::ParseTestData(const rapidjson::Value& o
for (const auto& m : item["filter"].GetObject()) {
if (m.value.IsString()) {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = m.value.GetString();
scalar_value.fields.push_back(field);
vector_with_id.scalar_data[m.name.GetString()] = scalar_value;
} else if (m.value.IsInt64()) {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = m.value.GetInt64();
scalar_value.fields.push_back(field);
Expand Down Expand Up @@ -916,7 +916,7 @@ bool MiraclDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorWithI

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = id;
scalar_value.fields.push_back(field);
Expand All @@ -925,7 +925,7 @@ bool MiraclDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorWithI

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = item["title"].GetString();
scalar_value.fields.push_back(field);
Expand All @@ -934,7 +934,7 @@ bool MiraclDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorWithI

{
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = item["text"].GetString();
scalar_value.fields.push_back(field);
Expand Down Expand Up @@ -976,14 +976,14 @@ Dataset::TestEntryPtr MiraclDataset::ParseTestData(const rapidjson::Value& obj)
const auto& value = kv[1];
if (key == "title") {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = value;
scalar_value.fields.push_back(field);
vector_with_id.scalar_data["title"] = scalar_value;
} else if (key == "text") {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = value;
scalar_value.fields.push_back(field);
Expand All @@ -995,14 +995,14 @@ Dataset::TestEntryPtr MiraclDataset::ParseTestData(const rapidjson::Value& obj)
for (const auto& m : item["filter"].GetObject()) {
if (m.value.IsString()) {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kString;
scalar_value.type = sdk::Type::kSTRING;
sdk::ScalarField field;
field.string_data = m.value.GetString();
scalar_value.fields.push_back(field);
vector_with_id.scalar_data[m.name.GetString()] = scalar_value;
} else if (m.value.IsInt64()) {
sdk::ScalarValue scalar_value;
scalar_value.type = sdk::ScalarFieldType::kInt64;
scalar_value.type = sdk::Type::kINT64;
sdk::ScalarField field;
field.long_data = m.value.GetInt64();
scalar_value.fields.push_back(field);
Expand Down
Loading

0 comments on commit 641339b

Please sign in to comment.