Skip to content

Commit

Permalink
feat: limiting the size of query results to Dashboard (#3901)
Browse files Browse the repository at this point in the history
* feat: limiting the size of query results to Dashboard

* optimize code

* fix by cr

* fix integration tests error

* remove RequestSource::parse

* refactor: sql query params

* fix: unit test

---------

Co-authored-by: tison <wander4096@gmail.com>
  • Loading branch information
realtaobo and tisonkun committed May 14, 2024
1 parent e15294d commit 494ce65
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 26 deletions.
34 changes: 34 additions & 0 deletions src/servers/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ impl From<SchemaRef> for OutputSchema {
pub struct HttpRecordsOutput {
schema: OutputSchema,
rows: Vec<Vec<Value>>,
// total_rows is equal to rows.len() in most cases,
// the Dashboard query result may be truncated, so we need to return the total_rows.
#[serde(default)]
total_rows: usize,

// plan level execution metrics
#[serde(skip_serializing_if = "HashMap::is_empty")]
Expand Down Expand Up @@ -224,6 +228,7 @@ impl HttpRecordsOutput {
Ok(HttpRecordsOutput {
schema: OutputSchema::from(schema),
rows: vec![],
total_rows: 0,
metrics: Default::default(),
})
} else {
Expand All @@ -244,6 +249,7 @@ impl HttpRecordsOutput {

Ok(HttpRecordsOutput {
schema: OutputSchema::from(schema),
total_rows: rows.len(),
rows,
metrics: Default::default(),
})
Expand Down Expand Up @@ -357,6 +363,34 @@ impl HttpResponse {
HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
}
}

/// Applies a row-count limit to the response payload.
///
/// Only the formats consumed by the Dashboard (CSV, table, and the
/// GreptimeDB v1 JSON format) are truncated; every other response kind
/// is returned unchanged.
pub fn with_limit(self, limit: usize) -> Self {
    match self {
        HttpResponse::Csv(csv) => csv.with_limit(limit).into(),
        HttpResponse::Table(table) => table.with_limit(limit).into(),
        HttpResponse::GreptimedbV1(v1) => v1.with_limit(limit).into(),
        other => other,
    }
}
}

/// Truncates every record set in `outputs` to at most `limit` rows.
///
/// Used to cap the size of query results returned to the Dashboard.
/// Non-record outputs (e.g. affected-rows counts) pass through untouched.
pub fn process_with_limit(
    outputs: Vec<GreptimeQueryOutput>,
    limit: usize,
) -> Vec<GreptimeQueryOutput> {
    // Consume the vector directly; the original `mut` + `drain(..)` only
    // emptied a Vec that was immediately dropped.
    outputs
        .into_iter()
        .map(|data| match data {
            GreptimeQueryOutput::Records(mut records) => {
                if records.rows.len() > limit {
                    records.rows.truncate(limit);
                    // NOTE(review): total_rows is clamped to the truncated
                    // length, so total_rows == rows.len() even after a cut —
                    // a client cannot detect truncation from this field
                    // alone. Confirm this matches the Dashboard contract.
                    records.total_rows = limit;
                }
                GreptimeQueryOutput::Records(records)
            }
            _ => data,
        })
        .collect()
}

impl IntoResponse for HttpResponse {
Expand Down
6 changes: 6 additions & 0 deletions src/servers/src/http/csv_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use mime_guess::mime;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::process_with_limit;
use crate::http::error_result::ErrorResponse;
use crate::http::header::{GREPTIME_DB_HEADER_EXECUTION_TIME, GREPTIME_DB_HEADER_FORMAT};
use crate::http::{handler, GreptimeQueryOutput, HttpResponse, ResponseFormat};
Expand Down Expand Up @@ -65,6 +66,11 @@ impl CsvResponse {
/// Returns the query's wall-clock execution time in milliseconds.
pub fn execution_time_ms(&self) -> u64 {
self.execution_time_ms
}

/// Returns `self` with every record set truncated to at most `limit` rows.
pub fn with_limit(mut self, limit: usize) -> Self {
    let limited = process_with_limit(self.output, limit);
    self.output = limited;
    self
}
}

impl IntoResponse for CsvResponse {
Expand Down
6 changes: 6 additions & 0 deletions src/servers/src/http/greptime_result_v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::header::GREPTIME_DB_HEADER_METRICS;
use super::process_with_limit;
use crate::http::header::{GREPTIME_DB_HEADER_EXECUTION_TIME, GREPTIME_DB_HEADER_FORMAT};
use crate::http::{handler, GreptimeQueryOutput, HttpResponse, ResponseFormat};

Expand Down Expand Up @@ -62,6 +63,11 @@ impl GreptimedbV1Response {
/// Returns the query's wall-clock execution time in milliseconds.
pub fn execution_time_ms(&self) -> u64 {
self.execution_time_ms
}

/// Returns `self` with every record set truncated to at most `limit` rows.
pub fn with_limit(mut self, limit: usize) -> Self {
    let capped = process_with_limit(self.output, limit);
    self.output = capped;
    self
}
}

impl IntoResponse for GreptimedbV1Response {
Expand Down
8 changes: 6 additions & 2 deletions src/servers/src/http/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub struct SqlQuery {
// specified time precision. Maybe greptimedb format can support this
// param too.
pub epoch: Option<String>,
pub limit: Option<usize>,
}

/// Handler to execute sql
Expand Down Expand Up @@ -98,7 +99,7 @@ pub async fn sql(
if let Some((status, msg)) = validate_schema(sql_handler.clone(), query_ctx.clone()).await {
Err((status, msg))
} else {
Ok(sql_handler.do_query(sql, query_ctx).await)
Ok(sql_handler.do_query(sql, query_ctx.clone()).await)
}
} else {
Err((
Expand All @@ -117,14 +118,17 @@ pub async fn sql(
Ok(outputs) => outputs,
};

let resp = match format {
let mut resp = match format {
ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
ResponseFormat::Table => TableResponse::from_output(outputs).await,
ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, epoch).await,
};

if let Some(limit) = query_params.limit {
resp = resp.with_limit(limit);
}
resp.with_execution_time(start.elapsed().as_millis() as u64)
}

Expand Down
6 changes: 6 additions & 0 deletions src/servers/src/http/table_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use mime_guess::mime;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::process_with_limit;
use crate::http::error_result::ErrorResponse;
use crate::http::header::{GREPTIME_DB_HEADER_EXECUTION_TIME, GREPTIME_DB_HEADER_FORMAT};
use crate::http::{handler, GreptimeQueryOutput, HttpResponse, ResponseFormat};
Expand Down Expand Up @@ -66,6 +67,11 @@ impl TableResponse {
/// Returns the query's wall-clock execution time in milliseconds.
pub fn execution_time_ms(&self) -> u64 {
self.execution_time_ms
}

/// Returns `self` with every record set truncated to at most `limit` rows.
pub fn with_limit(mut self, limit: usize) -> Self {
    self.output = {
        let output = self.output;
        process_with_limit(output, limit)
    };
    self
}
}

impl Display for TableResponse {
Expand Down
74 changes: 60 additions & 14 deletions src/servers/tests/http/http_handler_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use headers::HeaderValue;
use http_body::combinators::UnsyncBoxBody;
use hyper::Response;
use mime_guess::mime;
use servers::http::GreptimeQueryOutput::Records;
use servers::http::{
handler as http_handler, script as script_handler, ApiState, GreptimeOptionsConfigState,
GreptimeQueryOutput, HttpResponse,
Expand All @@ -48,10 +49,8 @@ async fn test_sql_not_provided() {

for format in ["greptimedb_v1", "influxdb_v1", "csv", "table"] {
let query = http_handler::SqlQuery {
db: None,
sql: None,
format: Some(format.to_string()),
epoch: None,
..Default::default()
};

let HttpResponse::Error(resp) = http_handler::sql(
Expand Down Expand Up @@ -82,8 +81,9 @@ async fn test_sql_output_rows() {
script_handler: None,
};

let query_sql = "select sum(uint32s) from numbers limit 20";
for format in ["greptimedb_v1", "influxdb_v1", "csv", "table"] {
let query = create_query(format);
let query = create_query(format, query_sql, None);
let json = http_handler::sql(
State(api_state.clone()),
query,
Expand Down Expand Up @@ -112,7 +112,8 @@ async fn test_sql_output_rows() {
[
4950
]
]
],
"total_rows": 1
}"#
);
}
Expand Down Expand Up @@ -176,6 +177,49 @@ async fn test_sql_output_rows() {
}
}

#[tokio::test]
async fn test_dashboard_sql_limit() {
let sql_handler = create_testing_sql_query_handler(MemTable::specified_numbers_table(2000));
let ctx = QueryContext::arc();
ctx.set_current_user(Some(auth::userinfo_by_name(None)));
let api_state = ApiState {
sql_handler,
script_handler: None,
};
for format in ["greptimedb_v1", "csv", "table"] {
let query = create_query(format, "select * from numbers", Some(1000));
let sql_response = http_handler::sql(
State(api_state.clone()),
query,
axum::Extension(ctx.clone()),
Form(http_handler::SqlQuery::default()),
)
.await;

match sql_response {
HttpResponse::GreptimedbV1(resp) => match resp.output().first().unwrap() {
Records(records) => {
assert_eq!(records.num_rows(), 1000);
}
_ => unreachable!(),
},
HttpResponse::Csv(resp) => match resp.output().first().unwrap() {
Records(records) => {
assert_eq!(records.num_rows(), 1000);
}
_ => unreachable!(),
},
HttpResponse::Table(resp) => match resp.output().first().unwrap() {
Records(records) => {
assert_eq!(records.num_rows(), 1000);
}
_ => unreachable!(),
},
_ => unreachable!(),
}
}
}

#[tokio::test]
async fn test_sql_form() {
common_telemetry::init_default_ut_logging();
Expand Down Expand Up @@ -219,7 +263,8 @@ async fn test_sql_form() {
[
4950
]
]
],
"total_rows": 1
}"#
);
}
Expand Down Expand Up @@ -393,7 +438,8 @@ def test(n) -> vector[i64]:
[
4
]
]
],
"total_rows": 5
}"#
);
}
Expand Down Expand Up @@ -460,7 +506,8 @@ def test(n, **params) -> vector[i64]:
[
46
]
]
],
"total_rows": 5
}"#
);
}
Expand All @@ -484,21 +531,20 @@ fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
})
}

fn create_query(format: &str) -> Query<http_handler::SqlQuery> {
/// Builds a `Query` extractor carrying the given SQL text, response
/// format, and optional row limit; all other fields keep their defaults.
fn create_query(format: &str, sql: &str, limit: Option<usize>) -> Query<http_handler::SqlQuery> {
    let mut params = http_handler::SqlQuery::default();
    params.sql = Some(sql.to_string());
    params.format = Some(format.to_string());
    params.limit = limit;
    Query(params)
}

/// Builds a `Form` extractor that runs a fixed aggregate query in the
/// requested response format; remaining fields keep their defaults.
fn create_form(format: &str) -> Form<http_handler::SqlQuery> {
    let mut params = http_handler::SqlQuery::default();
    params.sql = Some("select sum(uint32s) from numbers limit 20".to_string());
    params.format = Some(format.to_string());
    Form(params)
}

Expand Down
6 changes: 5 additions & 1 deletion src/table/src/test_util/memtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,18 @@ impl MemTable {
/// Creates a 1 column 100 rows table, with table name "numbers", column name "uint32s" and
/// column type "uint32". Column data increased from 0 to 100.
pub fn default_numbers_table() -> TableRef {
Self::specified_numbers_table(100)
}

pub fn specified_numbers_table(rows: u32) -> TableRef {
let column_schemas = vec![ColumnSchema::new(
"uint32s",
ConcreteDataType::uint32_datatype(),
true,
)];
let schema = Arc::new(Schema::new(column_schemas));
let columns: Vec<VectorRef> = vec![Arc::new(UInt32Vector::from_slice(
(0..100).collect::<Vec<_>>(),
(0..rows).collect::<Vec<_>>(),
))];
let recordbatch = RecordBatch::new(schema, columns).unwrap();
MemTable::table("numbers", recordbatch)
Expand Down
Loading

0 comments on commit 494ce65

Please sign in to comment.