Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

editoast: add conflict projection endpoint #9003

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ interface PathProperties {
@JvmName("getSpeedLimitProperties")
fun getSpeedLimitProperties(trainTag: String?): DistanceRangeMap<SpeedLimitProperty>

fun getZones(): DistanceRangeMap<ZoneId>

@JvmName("getLength") fun getLength(): Distance

@JvmName("getTrackLocationAtOffset")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ data class PathPropertiesImpl(
}
}

override fun getZones(): DistanceRangeMap<ZoneId> {
return getRangeMapFromUndirected { chunkId ->
val zoneId = infra.getTrackChunkZone(chunkId)
if (zoneId != null) {
val chunkLength = infra.getTrackChunkLength(chunkId).distance
distanceRangeMapOf(listOf(DistanceRangeMap.RangeMapEntry(Distance.ZERO, chunkLength, zoneId)))
} else {
distanceRangeMapOf()
}
}
}

override fun getLength(): Distance {
return chunkPath.endOffset - chunkPath.beginOffset
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ data class PathPropertiesView(
return sliceRangeMap(base.getSpeedLimitProperties(trainTag))
}

override fun getZones(): DistanceRangeMap<ZoneId> {
return sliceRangeMap(base.getZones())
}

override fun getLength(): Distance {
return endOffset - startOffset
}
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/java/fr/sncf/osrd/cli/ApiServerCommand.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.beust.jcommander.Parameters;
import fr.sncf.osrd.api.*;
import fr.sncf.osrd.api.api_v2.conflicts.ConflictDetectionEndpointV2;
import fr.sncf.osrd.api.api_v2.conflicts.ConflictProjectionEndpoint;
import fr.sncf.osrd.api.api_v2.path_properties.PathPropEndpoint;
import fr.sncf.osrd.api.api_v2.pathfinding.PathfindingBlocksEndpointV2;
import fr.sncf.osrd.api.api_v2.project_signals.SignalProjectionEndpointV2;
Expand Down Expand Up @@ -91,6 +92,7 @@ public int run() {
new FkRegex("/v2/signal_projection", new SignalProjectionEndpointV2(infraManager)),
new FkRegex("/detect_conflicts", new ConflictDetectionEndpoint()),
new FkRegex("/v2/conflict_detection", new ConflictDetectionEndpointV2(infraManager)),
new FkRegex("/conflict_projection", new ConflictProjectionEndpoint(infraManager)),
new FkRegex("/cache_status", new InfraCacheStatusEndpoint(infraManager)),
new FkRegex("/version", new VersionEndpoint()),
new FkRegex("/stdcm", new STDCMEndpoint(infraManager)),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package fr.sncf.osrd.api.api_v2.conflicts

import com.squareup.moshi.Json
import com.squareup.moshi.JsonAdapter
import com.squareup.moshi.Moshi
import com.squareup.moshi.Types
import com.squareup.moshi.kotlin.reflect.KotlinJsonAdapterFactory
import fr.sncf.osrd.api.ExceptionHandler
import fr.sncf.osrd.api.InfraManager
import fr.sncf.osrd.api.api_v2.TrackRange
import fr.sncf.osrd.reporting.warnings.DiagnosticRecorderImpl
import fr.sncf.osrd.utils.json.UnitAdapterFactory
import org.takes.Request
import org.takes.Response
import org.takes.Take
import org.takes.rq.RqPrint
import org.takes.rs.RsJson
import org.takes.rs.RsText
import org.takes.rs.RsWithBody
import org.takes.rs.RsWithStatus

class ConflictProjectionEndpoint(private val infraManager: InfraManager) : Take {
override fun act(req: Request?): Response {
return try {
val body = RqPrint(req).printBody()
val request =
conflictProjectionRequestAdapter.fromJson(body)
?: return RsWithStatus(RsText("missing request body"), 400)

if (request.zones.isEmpty()) {
return RsJson(RsWithBody(conflictProjectionResponseAdapter.toJson(listOf())))
}

val recorder = DiagnosticRecorderImpl(false)
val infra = infraManager.getInfra(request.infra, request.expectedVersion, recorder)

for (zoneName in request.zones) {
val zoneId = infra.rawInfra.getZoneFromName(zoneName)
val bounds =
infra.rawInfra.getZoneBounds(zoneId).map {
val trackSection = infra.rawInfra.getDetectorTrackSection(it)
val trackOffset = infra.rawInfra.getDetectorTrackOffset(it)
return 0
}
}

return RsJson(RsWithBody(conflictProjectionResponseAdapter.toJson(listOf())))
} catch (ex: Throwable) {
ExceptionHandler.handle(ex)
}
}
}

class ConflictProjectionRequest(
var infra: String,
@Json(name = "expected_version") var expectedVersion: String,
@Json(name = "path_track_ranges") val pathTrackRanges: Collection<TrackRange>,
val zones: Collection<String>,
)

val conflictProjectionRequestAdapter: JsonAdapter<ConflictProjectionRequest> =
Moshi.Builder()
.addLast(UnitAdapterFactory())
.addLast(KotlinJsonAdapterFactory())
.build()
.adapter(ConflictProjectionRequest::class.java)

val conflictProjectionResponseAdapter: JsonAdapter<Collection<Pair<Int, Int>>> =
Moshi.Builder()
.addLast(UnitAdapterFactory())
.addLast(KotlinJsonAdapterFactory())
.build()
.adapter(
Types.newParameterizedType(Collection::class.java, Pair::class.java, Int::class.java)
)
35 changes: 35 additions & 0 deletions editoast/src/core/conflict_projection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use serde::Deserialize;
use serde::Serialize;
use utoipa::ToSchema;

use crate::core::pathfinding::TrackRange;
use crate::core::{AsCoreRequest, Json};

editoast_common::schemas! {
ConflictProjectionResponse,
}

#[derive(Debug, Serialize)]
pub struct ConflictProjectionRequest {
pub infra: i64,
/// Infrastructure expected version
pub expected_version: String,

pub path_track_ranges: Vec<TrackRange>,
pub zones: Vec<String>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ConflictProjectionResponse {
#[schema(inline)]
pub path_position_ranges: Vec<(u64, u64)>,
}

impl AsCoreRequest<Json<ConflictProjectionResponse>> for ConflictProjectionRequest {
const METHOD: reqwest::Method = reqwest::Method::POST;
const URL_PATH: &'static str = "/v2/conflict_projection";

fn infra_id(&self) -> Option<i64> {
Some(self.infra)
}
}
1 change: 1 addition & 0 deletions editoast/src/core/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod conflict_detection;
pub mod conflict_projection;
mod http_client;
pub mod infra_loading;
#[cfg(test)]
Expand Down
92 changes: 92 additions & 0 deletions editoast/src/views/conflicts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use axum::extract::Json;
use axum::extract::State;
use axum::Extension;
use editoast_authz::BuiltinRole;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;

use crate::core::conflict_projection::{ConflictProjectionRequest, ConflictProjectionResponse};
use crate::core::pathfinding::TrackRange as CoreTrackRange;
use crate::core::AsCoreRequest;
use crate::error::Result;
use crate::models::prelude::*;
use crate::views::AuthorizationError;
use crate::views::AuthorizerExt;
use crate::{AppState, Infra};
use editoast_derive::EditoastError;

crate::routes! {
"/conflicts" => {
"/project_path" => project_path,
},
}

editoast_common::schemas! {
ConflictProjectForm,
}

#[derive(Debug, Error, EditoastError)]
#[editoast_error(base_id = "conflicts")]
pub enum ConflictError {
#[error("Infra '{infra_id}', could not be found")]
#[editoast_error(status = 404)]
InfraNotFound { infra_id: i64 },
}

#[derive(Serialize, Deserialize, ToSchema)]
struct ConflictProjectForm {
infra_id: i64,
#[schema(value_type = Vec<TrackRange>)]
path_track_ranges: Vec<CoreTrackRange>,
zones: Vec<String>,
}

#[utoipa::path(
post, path = "",
tag = "conflicts",
request_body = ConflictProjectForm,
responses(
(
status = 200,
body = ConflictProjectionResponse,
description = "Returns a list of conflicts whose track ranges intersect the given path"
),
)
)]
async fn project_path(
State(app_state): State<AppState>,
Extension(authorizer): AuthorizerExt,
Json(ConflictProjectForm {
infra_id,
path_track_ranges,
zones,
}): Json<ConflictProjectForm>,
) -> Result<Json<ConflictProjectionResponse>> {
let authorized = authorizer
.check_roles([BuiltinRole::InfraRead].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool_v2.clone();
let core = app_state.core_client.clone();

let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || {
ConflictError::InfraNotFound { infra_id }
})
.await?;

let req = ConflictProjectionRequest {
infra: infra_id,
expected_version: infra.version,
path_track_ranges,
zones,
};
let resp = req.fetch(core.as_ref()).await?;

Ok(Json(resp.into()))
}
2 changes: 2 additions & 0 deletions editoast/src/views/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod authz;
pub mod conflicts;
mod documents;
pub mod electrical_profiles;
pub mod infra;
Expand Down Expand Up @@ -68,6 +69,7 @@ crate::routes! {
"/version/core" => core_version,

&authz,
&conflicts,
&documents,
&electrical_profiles,
&infra,
Expand Down
Loading