

5070 create new gollm endpoint that takes in a list of latex equations and returns cleaned up ones (#5145)
dgauldie authored Oct 15, 2024
1 parent 895df79 commit e5cd64a
Showing 8 changed files with 240 additions and 20 deletions.
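The commit wires a new "gollm:equations_cleanup" task through the whole stack: a Pydantic input model (entities.py), a prompt plus a tightened LaTeX style guide (gollm_openai/prompts), an OpenAI helper (tool_utils.py), a taskrunner script (tasks/equations_cleanup.py), a setup.py entry point, and an HMI-server response handler hooked into KnowledgeController. A minimal sketch of the input the task accepts, using hypothetical SIR-style equations (the output shape is covered after the KnowledgeController diff below):

# Hypothetical input for the gollm:equations_cleanup task; the shape follows
# the EquationsCleanup model added in entities.py, and the equations themselves
# are illustrative only.
task_input = {
    "equations": [
        "\\dot{S} = -\\beta \\cdot S(t) \\cdot I(t)",
        "\\dot{I} = \\beta \\cdot S(t) \\cdot I(t) - \\gamma \\cdot I(t)",
    ]
}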
4 changes: 4 additions & 0 deletions packages/gollm/entities.py
@@ -29,6 +29,10 @@ class ModelCompareModel(BaseModel):
amrs: List[str] # expects AMRs to be a stringified JSON object


class EquationsCleanup(BaseModel):
equations: List[str]


class EquationsFromImage(BaseModel):
image: str # expects a base64 encoded image

24 changes: 24 additions & 0 deletions packages/gollm/gollm_openai/prompts/equations_cleanup.py
@@ -0,0 +1,24 @@
EQUATIONS_CLEANUP_PROMPT = """
You are a helpful agent designed to reformat latex equations based on a supplied style guide.
The style guide will contain a list of rules that you should follow when reformatting the equations. You should reformat the equations to match the style guide as closely as possible.
Do not respond in full sentences; only create a JSON object that satisfies the JSON schema specified in the response format.
Use the following style guide to ensure that your LaTeX equations are correctly formatted:
--- STYLE GUIDE START ---
{style_guide}
--- STYLE GUIDE END ---
The equations that you need to reformat are as follows:
--- EQUATIONS START ---
{equations}
--- EQUATIONS END ---
Answer:
"""
27 changes: 11 additions & 16 deletions packages/gollm/gollm_openai/prompts/latex_style_guide.py
@@ -1,19 +1,14 @@
LATEXT_STYLE_GUIDE = """
1) Derivatives must be written in Leibniz notation (for example, \\frac{d X}{d t}). Equations that are written in other notations, like Newton or Lagrange, should be converted.
1) Derivatives must be written in Leibniz notation (for example, '\\frac{d X}{d t}'). Equations that are written in other notations, like Newton or Lagrange, should be converted.
2) First-order derivative must be on the left of the equal sign
3) Use whitespace to indicate multiplication
a) "*" is optional but probably should be avoided
4) "(t)" is optional and probably should be avoided
5) Avoid superscripts and LaTeX superscripts "^", particularly to denote sub-variables
6) Subscripts using LaTeX "_" are permitted
a) Ensure that all characters used in the subscript are surrounded by a pair of curly brackets "{...}"
7) Avoid mathematical constants like pi or Euler's number
a) Replace them as floats with 3 decimal places of precision
8) Avoid parentheses
9) Avoid capital sigma and pi notations for summation and product
10) Avoid non-ASCII characters when possible
11) Avoid using homoglyphs
12) Avoid words or multi-character names for variables and names
a) Use camel case to express multi-word or multi-character names
13) Do not use \\cdot for multiplication. Use whitespace instead.
3) the use of '(t)' should be avoided
4) Avoid superscripts and LaTeX superscripts '^', particularly to denote sub-variables
5) Subscripts using LaTeX '_' are permitted. However, ensure that all characters used in the subscript are surrounded by a pair of curly brackets '{...}'
6) Avoid parentheses
7) Avoid capital sigma and pi notations for summation and product
8) Avoid non-ASCII characters when possible
9) Avoid using homoglyphs
10) Avoid words or multi-character names for variables and names. Use camel case to express multi-word or multi-character names
11) Do not use '\\cdot' or '*' to indicate multiplication. Use whitespace instead.
12) Avoid using notation for mathematical constants like 'e' and 'pi'. Use their actual values up to 3 decimal places instead.
"""
46 changes: 44 additions & 2 deletions packages/gollm/gollm_openai/tool_utils.py
@@ -13,6 +13,7 @@
CONFIGURE_FROM_DATASET_MATRIX_PROMPT
)
from gollm_openai.prompts.config_from_document import CONFIGURE_FROM_DOCUMENT_PROMPT
from gollm_openai.prompts.equations_cleanup import EQUATIONS_CLEANUP_PROMPT
from gollm_openai.prompts.equations_from_image import EQUATIONS_FROM_IMAGE_PROMPT
from gollm_openai.prompts.general_instruction import GENERAL_INSTRUCTION_PROMPT
from gollm_openai.prompts.interventions_from_document import INTERVENTIONS_FROM_DOCUMENT_PROMPT
@@ -59,6 +60,47 @@ def get_image_format_string(image_format: str) -> str:
return format_strings.get(image_format.lower())


def equations_cleanup(equations: List[str]) -> dict:
print("Reformatting equations...")

print("Uploading and validating equations schema...")
config_path = os.path.join(SCRIPT_DIR, 'schemas', 'equations.json')
with open(config_path, 'r') as config_file:
response_schema = json.load(config_file)
validate_schema(response_schema)

print("Building prompt to reformat equations...")
prompt = EQUATIONS_CLEANUP_PROMPT.format(
style_guide=LATEXT_STYLE_GUIDE,
equations="\n".join(equations)
)

client = OpenAI()
output = client.chat.completions.create(
model="gpt-4o-mini",
top_p=1,
frequency_penalty=0,
presence_penalty=0,
temperature=0,
seed=123,
max_tokens=1024,
response_format={
"type": "json_schema",
"json_schema": {
"name": "equations",
"schema": response_schema
}
},
messages=[
{"role": "user", "content": prompt},
]
)
print("Received response from OpenAI API. Formatting response to work with HMI...")
output_json = json.loads(output.choices[0].message.content)

return output_json


def equations_from_image(image: str) -> dict:
print("Translating equations from image...")

@@ -340,7 +382,7 @@ def embedding_chain(text: str) -> List:
return output.data[0].embedding


def model_config_from_dataset(amr: str, dataset: List[str], matrix: str) -> str:
def model_config_from_dataset(amr: str, dataset: List[str], matrix: str) -> dict:
print("Extracting datasets...")
dataset_text = os.linesep.join(dataset)

@@ -392,7 +434,7 @@ def model_config_from_dataset(amr: str, dataset: List[str], matrix: str) -> str:
return model_config_adapter(output_json)


def compare_models(amrs: List[str]) -> str:
def compare_models(amrs: List[str]) -> dict:
print("Comparing models...")

print("Building prompt to compare models...")
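A minimal usage sketch of the equations_cleanup helper added above, assuming the gollm package is importable and OPENAI_API_KEY is set in the environment; the keys of the returned dict are defined by schemas/equations.json, which is not shown in this diff:

# Hypothetical call to the new cleanup helper; the input equation and the
# commented result are illustrative only.
from gollm_openai.tool_utils import equations_cleanup

cleaned = equations_cleanup(equations=["\\dot{S} = -\\beta \\cdot S(t) \\cdot I(t)"])
print(cleaned)  # a dict matching the equations.json schema, e.g. {"equations": [...]}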
3 changes: 2 additions & 1 deletion packages/gollm/setup.py
@@ -22,11 +22,12 @@
"gollm:configure_model_from_document=tasks.configure_model_from_document:main",
"gollm:embedding=tasks.embedding:main",
"gollm:enrich_amr=tasks.enrich_amr:main",
"gollm:equations_cleanup=tasks.equations_cleanup:main",
"gollm:equations_from_image=tasks.equations_from_image:main",
"gollm:generate_response=tasks.generate_response:main",
"gollm:generate_summary=tasks.generate_summary:main",
"gollm:interventions_from_document=tasks.interventions_from_document:main",
"gollm:model_card=tasks.model_card:main",
"gollm:model_card=tasks.model_card:main"
],
},
python_requires=">=3.11",
40 changes: 40 additions & 0 deletions packages/gollm/tasks/equations_cleanup.py
@@ -0,0 +1,40 @@
import sys
from entities import EquationsCleanup
from gollm_openai.tool_utils import equations_cleanup

from taskrunner import TaskRunnerInterface


def cleanup():
pass


def main():
exitCode = 0
try:
taskrunner = TaskRunnerInterface(description="Reformat equations based on style guide")
taskrunner.on_cancellation(cleanup)

input_dict = taskrunner.read_input_dict_with_timeout()

taskrunner.log("Creating EquationsCleanup from input")
input_model = EquationsCleanup(**input_dict)

taskrunner.log("Sending request to OpenAI API")
response = equations_cleanup(equations=input_model.equations)
taskrunner.log("Received response from OpenAI API")

taskrunner.write_output_dict_with_timeout({"response": response})

except Exception as e:
sys.stderr.write(f"Error: {str(e)}\n")
sys.stderr.flush()
exitCode = 1

taskrunner.log("Shutting down")
taskrunner.shutdown()
sys.exit(exitCode)


if __name__ == "__main__":
main()
KnowledgeController.java
@@ -9,6 +9,7 @@
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import jakarta.annotation.PostConstruct;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
@@ -60,6 +61,7 @@
import software.uncharted.terarium.hmiserver.models.extractionservice.ExtractionResponse;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest.TaskType;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
import software.uncharted.terarium.hmiserver.proxies.mit.MitProxy;
import software.uncharted.terarium.hmiserver.proxies.skema.SkemaUnifiedProxy;
import software.uncharted.terarium.hmiserver.security.Roles;
@@ -73,6 +75,7 @@
import software.uncharted.terarium.hmiserver.service.data.ProvenanceSearchService;
import software.uncharted.terarium.hmiserver.service.data.ProvenanceService;
import software.uncharted.terarium.hmiserver.service.tasks.EnrichAmrResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.EquationsCleanupResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService.TaskMode;
import software.uncharted.terarium.hmiserver.utils.ByteMultipartFile;
@@ -106,11 +109,18 @@ public class KnowledgeController {
final ProjectService projectService;
final CurrentUserService currentUserService;

private final EquationsCleanupResponseHandler equationsCleanupResponseHandler;

final Messages messages;

@Value("${openai-api-key:}")
String OPENAI_API_KEY;

@PostConstruct
void init() {
taskService.addResponseHandler(equationsCleanupResponseHandler);
}

private void enrichModel(
final UUID projectId,
final UUID documentId,
@@ -236,9 +246,47 @@ public ResponseEntity<UUID> equationsToModel(
}
}

// Cleanup equations from the request
List<String> equations = new ArrayList<>();
if (req.get("equations") != null) {
for (final JsonNode equation : req.get("equations")) {
equations.add(equation.asText());
}
}
TaskRequest cleanupReq = cleanupEquationsTaskRequest(projectId, equations);
TaskResponse cleanupResp = null;
try {
cleanupResp = taskService.runTask(TaskMode.SYNC, cleanupReq);
} catch (final JsonProcessingException e) {
log.warn("Unable to clean-up equations due to a JsonProcessingException. Reverting to original equations.", e);
} catch (final TimeoutException e) {
log.warn("Unable to clean-up equations due to a TimeoutException. Reverting to original equations.", e);
} catch (final InterruptedException e) {
log.warn("Unable to clean-up equations due to a InterruptedException. Reverting to original equations.", e);
} catch (final ExecutionException e) {
log.warn("Unable to clean-up equations due to a ExecutionException. Reverting to original equations.", e);
}

// get the equations from the cleanup response, or use the original equations
JsonNode equationsReq = req.get("equations");
if (cleanupResp != null && cleanupResp.getOutput() != null) {
try {
JsonNode output = mapper.readValue(cleanupResp.getOutput(), JsonNode.class);
if (output.get("response") != null && output.get("response").get("equations") != null) {
equationsReq = output.get("response").get("equations");
}
} catch (IOException e) {
log.warn("Unable to retrive cleaned-up equations from GoLLM response. Reverting to original equations.", e);
}
}

// Create a new request with the cleaned-up equations, so that we don't modify the original request.
JsonNode newReq = req.deepCopy();
((ObjectNode) newReq).set("equations", equationsReq);

// Get an AMR from Skema Unified Service
try {
responseAMR = skemaUnifiedProxy.consolidatedEquationsToAMR(req).getBody();
responseAMR = skemaUnifiedProxy.consolidatedEquationsToAMR(newReq).getBody();
if (responseAMR == null) {
log.warn("Skema Unified Service did not return a valid AMR based on the provided equations");
throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("skema.bad-equations"));
@@ -840,4 +888,30 @@ private ResponseStatusException handleSkemaFeignException(final FeignException e
);
return new ResponseStatusException(httpStatus, messages.get("generic.unknown"));
}

private TaskRequest cleanupEquationsTaskRequest(UUID projectId, List<String> equations) {
final EquationsCleanupResponseHandler.Input input = new EquationsCleanupResponseHandler.Input();
input.setEquations(equations);

// Create the task
final TaskRequest req = new TaskRequest();
req.setType(TaskType.GOLLM);
req.setScript(EquationsCleanupResponseHandler.NAME);
req.setUserId(currentUserService.get().getId());

try {
req.setInput(mapper.writeValueAsBytes(input));
} catch (final Exception e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
}

req.setProjectId(projectId);

final EquationsCleanupResponseHandler.Properties props = new EquationsCleanupResponseHandler.Properties();
props.setProjectId(projectId);
req.setAdditionalProperties(props);

return req;
}
}
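The controller runs the cleanup task synchronously and falls back to the original "equations" node from the request whenever the task throws, times out, or returns a payload it cannot parse. The output it tries to unwrap is expected to look roughly like the sketch below; this is an assumption consistent with the parsing code above and with tasks/equations_cleanup.py, since the equations.json schema itself is not in this diff.

# Hypothetical TaskResponse output that KnowledgeController.equationsToModel
# unwraps; a missing "response" or "equations" key triggers the fallback path.
cleanup_output = {
    "response": {
        "equations": [
            "\\frac{d S}{d t} = -\\beta S I",
            "\\frac{d I}{d t} = \\beta S I - \\gamma I",
        ]
    }
}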
EquationsCleanupResponseHandler.java
@@ -0,0 +1,40 @@
package software.uncharted.terarium.hmiserver.service.tasks;

import com.fasterxml.jackson.databind.JsonNode;
import java.util.List;
import java.util.UUID;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Component
@RequiredArgsConstructor
@Slf4j
public class EquationsCleanupResponseHandler extends TaskResponseHandler {

public static final String NAME = "gollm:equations_cleanup";

@Override
public String getName() {
return NAME;
}

@Data
public static class Input {

List<String> equations;
}

@Data
public static class Properties {

UUID projectId;
}

@Data
public static class Response {

JsonNode response;
}
}
