Skip to content

Commit

Permalink
DH-4546/refactoring the code to work with Decimal (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza authored Aug 29, 2023
1 parent b0ff0fc commit e80941d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
10 changes: 0 additions & 10 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from bson import json_util
from fastapi import HTTPException
from overrides import override
from decimal import Decimal

from dataherald.api import API
from dataherald.api.types import Query
Expand Down Expand Up @@ -119,15 +118,6 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse:
raise HTTPException(status_code=404, detail=str(e)) from e
generated_answer.confidence_score = confidence_score
generated_answer.exec_time = time.time() - start_generated_answer

if hasattr(generated_answer.sql_query_result, 'columns'):
for col in generated_answer.sql_query_result.columns:
if hasattr(generated_answer.sql_query_result, 'rows'):
for row in generated_answer.sql_query_result.rows:
if isinstance(row[col], Decimal):
# Explicitly convert to float to ensure serialization
row[col] = float(row[col])

nl_query_response_repository = NLQueryResponseRepository(self.storage)
nl_query_response = nl_query_response_repository.insert(generated_answer)
return json.loads(json_util.dumps(nl_query_response))
Expand Down
5 changes: 5 additions & 0 deletions dataherald/sql_generator/create_sql_query_status.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import date
from decimal import Decimal

from sqlalchemy import text

Expand Down Expand Up @@ -33,6 +34,10 @@ def create_sql_query_status(
type(value) is date
): # Check if the value is an instance of datetime.date
modified_row[key] = str(value)
elif (
type(value) is Decimal
): # Check if the value is an instance of decimal.Decimal
modified_row[key] = float(value)
else:
modified_row[key] = value
rows.append(modified_row)
Expand Down

0 comments on commit e80941d

Please sign in to comment.