Commit

Merge branch 'master' of https://github.com/Limmen/csle
Limmen committed Aug 17, 2023
2 parents 9608040 + b568282 commit 2e28a83
Showing 6 changed files with 39 additions and 26 deletions.
@@ -96,14 +96,14 @@ def to_dict(self) -> Dict[str, Any]:
:return: a dict representation of the object
"""
- d = {}
+ d: Dict[str, Any] = {}
d["port_id"] = self.port_id
d["protocol"] = self.protocol
d["status"] = self.status
d["service_name"] = self.service_name
- d["http_enum"] = self.http_enum.to_dict()
- d["http_grep"] = self.http_grep.to_dict()
- d["vulscan"] = self.vulscan.to_dict()
+ d["http_enum"] = self.http_enum.to_dict() if self.http_enum is not None else self.http_enum
+ d["http_grep"] = self.http_grep.to_dict() if self.http_grep is not None else self.http_grep
+ d["vulscan"] = self.vulscan.to_dict() if self.vulscan is not None else self.vulscan
d["service_version"] = self.service_version
d["service_fp"] = self.service_fp
return d
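
The None-guards added to to_dict above exist because calling .to_dict() on an attribute that is None raises AttributeError. A minimal, self-contained sketch of the same pattern, using hypothetical stand-in classes rather than the actual csle DTOs:

from typing import Any, Dict, Optional


class SubResult:
    """Hypothetical nested DTO with its own to_dict()."""

    def __init__(self, output: str) -> None:
        self.output = output

    def to_dict(self) -> Dict[str, Any]:
        return {"output": self.output}


class PortResult:
    """Hypothetical DTO; the nested result may be absent, so it is Optional."""

    def __init__(self, port_id: int, http_enum: Optional[SubResult] = None) -> None:
        self.port_id = port_id
        self.http_enum = http_enum

    def to_dict(self) -> Dict[str, Any]:
        d: Dict[str, Any] = {}
        d["port_id"] = self.port_id
        # Calling to_dict() on a None attribute would raise AttributeError,
        # so None is passed through unchanged when the sub-object is absent.
        d["http_enum"] = self.http_enum.to_dict() if self.http_enum is not None else self.http_enum
        return d


print(PortResult(port_id=80).to_dict())           # {'port_id': 80, 'http_enum': None}
print(PortResult(80, SubResult("ok")).to_dict())  # {'port_id': 80, 'http_enum': {'output': 'ok'}}
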
@@ -9,7 +9,7 @@ class EmulationPortObservationState(JSONSerializable):
DTO representation of a port observation in the emulation
"""

- def __init__(self, port: int, open: bool, service: int, protocol: TransportProtocol, http_enum: str = "",
+ def __init__(self, port: int, open: bool, service: str, protocol: TransportProtocol, http_enum: str = "",
http_grep: str = "", vulscan: str = "", version: str = "", fingerprint: str = ""):
"""
Initializes the DTO
@@ -1,5 +1,6 @@
- from typing import List, Dict, Union, Any
+ from typing import List, Dict, Union, Any, Optional
import numpy as np
+ from numpy.typing import NDArray
import torch
from stable_baselines3 import DQN
from csle_common.dao.training.policy import Policy
@@ -18,7 +19,7 @@ class DQNPolicy(Policy):
A neural network policy learned with DQN
"""

- def __init__(self, model, simulation_name: str, save_path: str, player_type: PlayerType, states: List[State],
+ def __init__(self, model: Optional[DQN], simulation_name: str, save_path: str, player_type: PlayerType, states: List[State],
actions: List[Action], experiment_config: ExperimentConfig, avg_R: float):
"""
Initializes the policy
@@ -49,23 +50,27 @@ def __init__(self, model, simulation_name: str, save_path: str, player_type: Pla
self.avg_R = avg_R
self.policy_type = PolicyType.DQN

- def action(self, o: List[float]) -> Union[int, List[int], np.ndarray[Any, Any]]:
+ def action(self, o: List[float]) -> NDArray[Any]:
"""
Multi-threshold stopping policy
:param o: the current observation
:return: the selected action
"""
- a, _ = self.model.predict(np.array(o), deterministic=False)
+ if self.model is None:
+ raise ValueError("The model is None")
+ a = self.model.predict(np.array(o), deterministic=False)[0]
return a

- def probability(self, o: List[float], a) -> Union[int, List[int], np.ndarray[Any, Any]]:
+ def probability(self, o: List[float], a) -> int:
"""
Multi-threshold stopping policy
:param o: the current observation
:return: the selected action
"""
+ if self.model is None:
+ raise ValueError("The model is None")
actions = self.model.policy.forward(obs=torch.tensor(o).to(self.model.device))
action = actions[0]
if action == a:
@@ -142,6 +147,8 @@ def _get_attacker_stopping_dist(self, obs) -> List[float]:
:param obs: the observation to condition on
:return: the conditional action distribution
"""
+ if self.model is None:
+ raise ValueError("The model is None")
obs = np.array([obs])
actions = self.model.policy.forward(obs=torch.tensor(obs).to(self.model.device))
action = actions[0]
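
The recurring `if self.model is None: raise ValueError(...)` guard added throughout DQNPolicy narrows the now-Optional model attribute before it is used, failing fast at runtime and satisfying static type checkers. A minimal sketch of the pattern with a toy predictor standing in for the stable-baselines3 DQN (ToyPredictor and GuardedPolicy are illustrative names, not csle classes):

from typing import Any, List, Optional, Tuple

import numpy as np
from numpy.typing import NDArray


class ToyPredictor:
    """Stand-in for a trained model exposing predict()."""

    def predict(self, obs: NDArray[Any], deterministic: bool = False) -> Tuple[NDArray[Any], None]:
        # A real model would run a forward pass; this toy always picks action 0.
        return np.array([0]), None


class GuardedPolicy:
    def __init__(self, model: Optional[ToyPredictor]) -> None:
        # The model may be None, e.g. when a policy is loaded as metadata only.
        self.model: Optional[ToyPredictor] = model

    def action(self, o: List[float]) -> NDArray[Any]:
        if self.model is None:
            raise ValueError("The model is None")
        # After the guard, type checkers can treat self.model as ToyPredictor.
        return self.model.predict(np.array(o), deterministic=False)[0]


print(GuardedPolicy(ToyPredictor()).action([0.1, 0.2]))  # [0]
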
@@ -1,5 +1,6 @@
- from typing import List, Dict, Any
+ from typing import List, Dict, Any, Union
from abc import abstractmethod
+ from numpy.typing import NDArray
from csle_common.dao.training.agent_type import AgentType
from csle_common.dao.training.player_type import PlayerType
from csle_base.json_serializable import JSONSerializable
@@ -21,7 +22,7 @@ def __init__(self, agent_type: AgentType, player_type: PlayerType) -> None:
self.player_type = player_type

@abstractmethod
- def action(self, o: Any) -> Any:
+ def action(self, o: Any) -> NDArray[Any]:
"""
Calculates the next action
@@ -61,7 +62,7 @@ def from_dict(d: Dict) -> "Policy":
pass

@abstractmethod
- def probability(self, o: Any, a: int) -> float:
+ def probability(self, o: Any, a: int) -> Union[int, float]:
"""
Calculates the probability of a given action for a given observation
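
The widened signatures in the abstract Policy base (action returning NDArray[Any], probability returning Union[int, float]) cover both kinds of subclass touched by this commit: DQNPolicy.probability returns a 0/1 indicator, while PPOPolicy.probability returns a float. A self-contained sketch of that shape, with toy classes rather than the csle ones:

from abc import ABC, abstractmethod
from typing import Any, Union

import numpy as np
from numpy.typing import NDArray


class BasePolicy(ABC):

    @abstractmethod
    def action(self, o: Any) -> NDArray[Any]:
        ...

    @abstractmethod
    def probability(self, o: Any, a: int) -> Union[int, float]:
        ...


class IndicatorPolicy(BasePolicy):
    """Like DQNPolicy here: probability is a 0/1 indicator."""

    def action(self, o: Any) -> NDArray[Any]:
        return np.array([1])

    def probability(self, o: Any, a: int) -> Union[int, float]:
        return 1 if a == 1 else 0


class StochasticPolicy(BasePolicy):
    """Like PPOPolicy here: probability is a real number in [0, 1]."""

    def action(self, o: Any) -> NDArray[Any]:
        return np.array([0])

    def probability(self, o: Any, a: int) -> Union[int, float]:
        return 0.75 if a == 0 else 0.25


print(IndicatorPolicy().probability(o=[0.0], a=1))   # 1
print(StochasticPolicy().probability(o=[0.0], a=0))  # 0.75
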
@@ -1,5 +1,6 @@
- from typing import List, Dict, Union, Any
+ from typing import List, Dict, Union, Any, Optional
import numpy as np
+ from numpy.typing import NDArray
import torch
import math
import iteround
@@ -19,7 +20,7 @@ class PPOPolicy(Policy):
A neural network policy learned with PPO
"""

- def __init__(self, model, simulation_name: str, save_path: str, player_type: PlayerType, states: List[State],
+ def __init__(self, model: Optional[PPO], simulation_name: str, save_path: str, player_type: PlayerType, states: List[State],
actions: List[Action], experiment_config: ExperimentConfig, avg_R: float):
"""
Initializes the policy
@@ -50,14 +51,16 @@ def __init__(self, model, simulation_name: str, save_path: str, player_type: Pla
self.avg_R = avg_R
self.policy_type = PolicyType.PPO

- def action(self, o: Union[List[float], List[int]]) -> Union[int, List[int], np.ndarray[Any, Any]]:
+ def action(self, o: Union[List[float], List[int]]) -> NDArray[Any]:
"""
Multi-threshold stopping policy
:param o: the current observation
:return: the selected action
"""
- a, _ = self.model.predict(np.array(o), deterministic=False)
+ if self.model is None:
+ raise ValueError("The model is None")
+ a = self.model.predict(np.array(o), deterministic=False)[0]
return a

def probability(self, o: Union[List[float], List[int]], a: int) -> float:
@@ -68,6 +71,8 @@ def probability(self, o: Union[List[float], List[int]], a: int) -> float:
:param o: the action
:return: the probability of the action
"""
+ if self.model is None:
+ raise ValueError("The model is None")
prob = math.exp(self.model.policy.get_distribution(obs=torch.tensor([o]).to(self.model.device)).log_prob(
actions=torch.tensor(a)).item())
return prob
@@ -134,6 +139,8 @@ def _get_attacker_dist(self, obs) -> List[float]:
:return: the conditional action distribution
"""
obs = np.array([obs])
+ if self.model is None:
+ raise ValueError("The model is None")
actions, values, log_prob = self.model.policy.forward(obs=torch.tensor(obs).to(self.model.device))
action = actions[0]
if action == 1:
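
PPOPolicy.probability above computes exp(log_prob(a)) from the policy's action distribution. The same computation with a plain torch categorical distribution (toy logits, not a trained stable-baselines3 policy):

import math

import torch

# Toy logits over three actions; a trained policy would produce these from an observation.
dist = torch.distributions.Categorical(logits=torch.tensor([1.0, 0.5, -0.2]))
a = 0
prob = math.exp(dist.log_prob(torch.tensor(a)).item())
print(round(prob, 4))  # probability of action 0 under the toy distribution
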
@@ -156,15 +156,13 @@ def on_assign(consumer, partitions):
agg_docker_stats.append(DockerStats.from_kafka_record(record=msg.value().decode()))
elif topic == collector_constants.KAFKA_CONFIG.HOST_METRICS_TOPIC_NAME:
metrics = HostMetrics.from_kafka_record(record=msg.value().decode())
- if emulation_env_config.get_container_from_ip(metrics.ip) is None:
- raise ValueError("NodeContainerConfig is None")
- c = emulation_env_config.get_container_from_ip(metrics.ip)
+ c_1 = emulation_env_config.get_container_from_ip(metrics.ip)
host_metrics[c.get_full_name()].append(metrics)
host_metrics_counter += 1
total_host_metrics.append(metrics)
elif topic == collector_constants.KAFKA_CONFIG.OSSEC_IDS_LOG_TOPIC_NAME:
metrics = OSSECIdsAlertCounters.from_kafka_record(record=msg.value().decode())
- c = emulation_env_config.get_container_from_ip(metrics.ip)
+ c_1 = emulation_env_config.get_container_from_ip(metrics.ip)
ossec_host_ids_metrics[c.get_full_name()].append(metrics)
ossec_host_metrics_counter += 1
total_ossec_metrics.append(metrics)
@@ -174,17 +172,17 @@ def on_assign(consumer, partitions):
defender_actions.append(EmulationDefenderAction.from_kafka_record(record=msg.value().decode()))
elif topic == collector_constants.KAFKA_CONFIG.DOCKER_HOST_STATS_TOPIC_NAME:
stats = DockerStats.from_kafka_record(record=msg.value().decode())
- c = emulation_env_config.get_container_from_ip(stats.ip)
+ c_1 = emulation_env_config.get_container_from_ip(stats.ip)
docker_host_stats[c.get_full_name()].append(stats)
elif topic == collector_constants.KAFKA_CONFIG.SNORT_IDS_LOG_TOPIC_NAME:
metrics = SnortIdsAlertCounters.from_kafka_record(record=msg.value().decode())
- c = emulation_env_config.get_container_from_ip(metrics.ip)
+ c_1 = emulation_env_config.get_container_from_ip(metrics.ip)
snort_alert_metrics_per_ids[c.get_full_name()].append(metrics)
snort_metrics_counter += 1
total_snort_metrics.append(metrics)
elif topic == collector_constants.KAFKA_CONFIG.SNORT_IDS_RULE_LOG_TOPIC_NAME:
metrics = SnortIdsRuleCounters.from_kafka_record(record=msg.value().decode())
- c = emulation_env_config.get_container_from_ip(metrics.ip)
+ c_1 = emulation_env_config.get_container_from_ip(metrics.ip)
snort_rule_metrics_per_ids[c.get_full_name()].append(metrics)
snort_rule_metrics_counter += 1
total_snort_rule_metrics.append(metrics)
@@ -227,8 +225,8 @@ def on_assign(consumer, partitions):
agg_flow_statistics_record)
elif topic == collector_constants.KAFKA_CONFIG.SNORT_IDS_IP_LOG_TOPIC_NAME:
metrics = SnortIdsIPAlertCounters.from_kafka_record(record=msg.value().decode())
- c = emulation_env_config.get_container_from_ip(metrics.alert_ip)
- if c is not None:
+ c_1 = emulation_env_config.get_container_from_ip(metrics.alert_ip)
+ if c_1 is not None:
snort_ids_ip_metrics[c.get_full_name()].append(metrics)
if host_metrics_counter >= len(emulation_env_config.containers_config.containers):
agg_host_metrics_dto = ReadEmulationStatisticsUtil.average_host_metrics(
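
The last file repeatedly resolves a container from an IP before appending metrics keyed by its full name; the lookup can return None, so a guard is needed before calling get_full_name(). A minimal sketch of that guard-then-use pattern with a hypothetical lookup function (not the csle API):

from typing import Dict, List, Optional


class Container:
    """Hypothetical container DTO."""

    def __init__(self, name: str) -> None:
        self.name = name

    def get_full_name(self) -> str:
        return self.name


containers: Dict[str, Container] = {"10.0.0.2": Container("node-2")}
host_metrics: Dict[str, List[str]] = {"node-2": []}


def get_container_from_ip(ip: str) -> Optional[Container]:
    """Hypothetical lookup; returns None for unknown IPs."""
    return containers.get(ip)


def record_metric(ip: str, metric: str) -> None:
    c = get_container_from_ip(ip)
    if c is None:
        # Guard before use: calling get_full_name() on None would raise AttributeError.
        raise ValueError("NodeContainerConfig is None")
    host_metrics[c.get_full_name()].append(metric)


record_metric("10.0.0.2", "cpu=12%")
print(host_metrics)  # {'node-2': ['cpu=12%']}
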