Yaml DOS checks
kshakir committed Jun 21, 2019
1 parent 5142bf1 commit 778b201
Showing 13 changed files with 334 additions and 23 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -37,6 +37,12 @@ which now has been updated to `services.Instrumentation.config`. More info on it
A new experimental feature, the `cached-copy` localization strategy, is available for the shared filesystem.
More information can be found in the [documentation on localization](https://cromwell.readthedocs.io/en/stable/backends/HPC).

#### Yaml node limits

Yaml parsing now checks for cycles and limits the number of parsed nodes to a configurable maximum. See
[the documentation on configuring Yaml](https://cromwell.readthedocs.io/en/stable/Configuring/#yaml) for more
information.
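
For illustration, a minimal sketch of what the cycle check rejects. The snippet and the YAML document are hypothetical; `wom.util.YamlUtils.parse` and its `Left(ParsingFailure(...))` result come from this commit, and the example assumes the `wom` `reference.conf` (which supplies the default `yaml.max-nodes`) is on the classpath.

```scala
import wom.util.YamlUtils

// A mapping that contains itself via a YAML anchor/alias pair. SnakeYAML
// constructs the recursive structure, and the new check reports it as a
// parsing failure instead of recursing forever.
val cyclicYaml =
  """|self: &me
     |  again: *me
     |""".stripMargin

YamlUtils.parse(cyclicYaml) match {
  case Left(failure) => println(failure.message) // "Loop detected"
  case Right(json)   => println(json.noSpaces)
}
```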

### API Changes

#### Workflow Metadata
@@ -10,6 +10,6 @@ submit {
statusCode: 400
message: """{
"status": "fail",
"message": "Error(s): Input file is not a valid yaml or json. Inputs data: ''. Error: MatchError: null."
"message": "Error(s): Input file is not a valid yaml or json. Inputs data: ''. Error: ParsingFailure: null."
}"""
}
@@ -1,14 +1,15 @@
package centaur.cwl
import better.files.File
import com.typesafe.config.Config
import common.util.StringUtil._
import common.validation.IOChecked.IOChecked
import cwl.preprocessor.CwlPreProcessor
import io.circe.optics.JsonPath
import io.circe.optics.JsonPath._
import io.circe.yaml.Printer.StringStyle
import io.circe.{Json, yaml}
import net.ceedubs.ficus.Ficus._
import common.util.StringUtil._
import wom.util.YamlUtils

/**
* Tools to pre-process the CWL workflows and inputs before feeding them to Cromwell so they can be executed on PAPI.
@@ -39,7 +40,7 @@ class CloudPreprocessor(config: Config, prefixConfigPath: String) {

// Parse value, apply f to it, and print it back to String using the printer
private def process(value: String, f: Json => Json, printer: Json => String) = {
yaml.parser.parse(value) match {
YamlUtils.parse(value) match {
case Left(error) => throw new Exception(error.getMessage)
case Right(json) => printer(f(json))
}
5 changes: 3 additions & 2 deletions centaurCwlRunner/src/test/scala/CloudPreprocessorSpec.scala
@@ -1,15 +1,16 @@
import centaur.cwl.CloudPreprocessor
import com.typesafe.config.ConfigFactory
import org.scalatest.{FlatSpec, Matchers}
import wom.util.YamlUtils

class CloudPreprocessorSpec extends FlatSpec with Matchers {
behavior of "PAPIPreProcessor"

val pAPIPreprocessor = new CloudPreprocessor(ConfigFactory.load(), "papi.default-input-gcs-prefix")

def validate(result: String, expectation: String) = {
val parsedResult = io.circe.yaml.parser.parse(result).right.get
val parsedExpectation = io.circe.yaml.parser.parse(expectation).right.get
val parsedResult = YamlUtils.parse(result).right.get
val parsedExpectation = YamlUtils.parse(expectation).right.get

// This is an actual Json comparison from circe
parsedResult shouldBe parsedExpectation
4 changes: 3 additions & 1 deletion cwl/src/main/scala/cwl/CommandLineTool.scala
@@ -17,6 +17,7 @@ import wom.callable.{Callable, CallableTaskDefinition, ContainerizedInputExpress
import wom.expression.{IoFunctionSet, ValueAsAnExpression, WomExpression}
import wom.graph.GraphNodePort.OutputPort
import wom.types.{WomArrayType, WomIntegerType, WomOptionalType}
import wom.util.YamlUtils
import wom.values.{WomArray, WomEvaluatedCallInputs, WomGlobFile, WomInteger, WomString, WomValue}
import wom.{CommandPart, RuntimeAttributes, RuntimeAttributesKeys}

@@ -167,8 +168,9 @@ case class CommandLineTool private(

// Parse content as json and return output values for each output port
def parseContent(content: String): EvaluatedOutputs = {
val yaml = YamlUtils.parse(content)
for {
parsed <- io.circe.yaml.parser.parse(content).flatMap(_.as[Map[String, Json]]).leftMap(error => NonEmptyList.one(error.getMessage))
parsed <- yaml.flatMap(_.as[Map[String, Json]]).leftMap(error => NonEmptyList.one(error.getMessage))
jobOutputsMap <- jsonToOutputs(parsed)
} yield jobOutputsMap.toMap
}
5 changes: 3 additions & 2 deletions cwl/src/main/scala/cwl/CwlExecutableValidation.scala
@@ -3,18 +3,19 @@ package cwl
import common.Checked
import common.validation.Checked._
import io.circe.Json
import io.circe.yaml
import wom.callable.{ExecutableCallable, TaskDefinition}
import wom.executable.Executable
import wom.executable.Executable.{InputParsingFunction, ParsedInputMap}
import wom.expression.IoFunctionSet
import wom.util.YamlUtils

object CwlExecutableValidation {

// Decodes the input file and builds the ParsedInputMap
private val inputCoercionFunction: InputParsingFunction =
inputFile => {
yaml.parser.parse(inputFile).flatMap(_.as[Map[String, Json]]) match {
val yaml = YamlUtils.parse(inputFile)
yaml.flatMap(_.as[Map[String, Json]]) match {
case Left(error) => error.getMessage.invalidNelCheck[ParsedInputMap]
case Right(inputValue) => inputValue.map({ case (key, value) => key -> value.foldWith(CwlJsonToDelayedCoercionFunction) }).validNelCheck
}
4 changes: 3 additions & 1 deletion cwl/src/main/scala/cwl/preprocessor/CwlPreProcessor.scala
@@ -14,6 +14,7 @@ import io.circe.optics.JsonPath._
import io.circe.{Json, JsonNumber, JsonObject}
import mouse.all._
import org.slf4j.LoggerFactory
import wom.util.YamlUtils

import scala.concurrent.ExecutionContext

@@ -218,7 +219,8 @@ object CwlPreProcessor {
}

private [preprocessor] def parseYaml(in: String): IOChecked[Json] = {
io.circe.yaml.parser.parse(in).leftMap(error => NonEmptyList.one(error.message)).toIOChecked
val yaml = YamlUtils.parse(in)
yaml.leftMap(error => NonEmptyList.one(error.message)).toIOChecked
}

/**
20 changes: 18 additions & 2 deletions docs/Configuring.md
@@ -297,7 +297,7 @@ url = "jdbc:mysql://host/cromwell?rewriteBatchedStatements=true&serverTimezone=U

Using this option does not alter your database's underlying timezone; rather, it causes Cromwell to "speak UTC" when communicating with the DB, and the DB server performs the conversion for you.

## Abort
### Abort

**Control-C (SIGINT) abort handler**

@@ -514,7 +514,7 @@ Cromwell writes one batch of workflow heartbeats at a time. While the internal q
a configurable threshold then [instrumentation](developers/Instrumentation.md) may send a metric signal that the
heartbeat load is above normal.

This threshold may be configured the configuration value:
This threshold may be configured via the configuration value:

```hocon
system.workflow-heartbeats {
@@ -523,3 +523,19 @@
```

The default threshold value is 100, just like the default for the heartbeat batch size.

### YAML

Cromwell throws an error when it detects a cycle in YAML input. However, even small acyclic YAML documents can be
crafted to consume significant amounts of memory or CPU. To bound the amount of work done during parsing, Cromwell
limits the number of nodes parsed per YAML document.

This limit may be configured via the configuration value:

```hocon
yaml {
max-nodes = 1000000
}
```

The default limit is 1,000,000 nodes.
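
As a sketch of how the limit interacts with alias expansion: the helper `YamlUtils.parse` and its refined `maxNodes` parameter are part of this change, while the sample document, the low limit of 25, and the value names below are hypothetical, chosen only to trigger the rejection path.

```scala
import eu.timepit.refined.numeric.NonNegative
import eu.timepit.refined.refineV
import wom.util.YamlUtils

// Aliases are re-traversed each time they appear, so a short document can
// expand into many more nodes than it has lines (billion-laughs style).
val expandingYaml =
  """|a: &a ["lol", "lol", "lol"]
     |b: &b [*a, *a, *a]
     |c: [*b, *b, *b]
     |""".stripMargin

// Hypothetical low limit for demonstration; production deployments rely on
// the `yaml.max-nodes` configuration value, which defaults to 1,000,000.
val lowLimit = refineV[NonNegative](25).right.get

YamlUtils.parse(expandingYaml, lowLimit) match {
  case Left(failure) => println(s"Rejected: ${failure.message}")
  case Right(json)   => println(s"Parsed: ${json.noSpaces}")
}
```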
@@ -2,7 +2,6 @@ package cromwell.webservice

import java.net.URL

import _root_.io.circe.yaml
import akka.util.ByteString
import cats.data.NonEmptyList
import cats.data.Validated._
@@ -21,6 +20,7 @@ import org.slf4j.LoggerFactory
import spray.json._
import wdl.draft2.model.WorkflowJson
import wom.core._
import wom.util.YamlUtils

import scala.util.{Failure, Success, Try}

@@ -141,7 +141,7 @@ object PartialWorkflowSources {
val onHold: ErrorOr[Boolean] = getBooleanValue(workflowOnHoldKey).getOrElse(false.validNel)

(unrecognized, workflowSourceFinal, workflowInputs, workflowInputsAux, workflowDependenciesFinal, onHold) mapN {
case (_, source, inputs, aux, dep, onHold) => PartialWorkflowSources(
case (_, source, inputs, aux, dep, onHoldActual) => PartialWorkflowSources(
workflowSource = source,
workflowUrl = workflowUrl,
workflowRoot = getStringValue(WorkflowRootKey),
@@ -153,7 +153,7 @@
customLabels = getStringValue(labelsKey),
zippedImports = dep,
warnings = wdlSourceWarning.toVector ++ wdlDependenciesWarning.toVector,
workflowOnHold = onHold
workflowOnHold = onHoldActual
)
}
}
@@ -169,7 +169,7 @@
import cats.syntax.validated._

val parseInputsTry = Try {
yaml.parser.parse(data) match {
YamlUtils.parse(data) match {
// If it's an array, treat each element as an individual input object, otherwise simply toString the whole thing
case Right(json) => json.asArray.map(_.map(_.toString())).getOrElse(Vector(json.pretty(Printer.noSpaces))).validNel
case Left(error) => s"Input file is not a valid yaml or json. Inputs data: '$data'. Error: ${ExceptionUtils.getMessage(error)}.".invalidNel
@@ -265,12 +265,11 @@

private def toMap(someInput: Option[String]): ErrorOr[Map[String, JsValue]] = {
someInput match {
case Some(input: String) => {
case Some(input: String) =>
Try(input.parseJson).toErrorOrWithContext(s"parse input: '$input', which is not a valid json. Please check for syntactical errors.") flatMap {
case JsObject(inputMap) => inputMap.validNel
case j: JsValue => s"Submitted input '$input' of type ${j.getClass.getSimpleName} is not a JSON object.".invalidNel
}
}
case None => Map.empty[String, JsValue].validNel
}
}
10 changes: 4 additions & 6 deletions project/Dependencies.scala
@@ -62,7 +62,7 @@ object Dependencies {
private val paradiseV = "2.1.1"
private val pegdownV = "1.6.0"
private val rdf4jV = "2.4.2"
private val refinedV = "0.9.4"
private val refinedV = "0.9.8"
private val rhinoV = "1.7.10"
private val scalaGraphV = "1.12.5"
private val scalaLoggingV = "3.9.2"
@@ -370,8 +370,7 @@
"org.scalacheck" %% "scalacheck" % scalacheckV % Test,
"com.github.mpilquist" %% "simulacrum" % simulacrumV,
"commons-codec" % "commons-codec" % commonsCodecV,
"eu.timepit" %% "refined" % refinedV
)
) ++ circeDependencies ++ refinedTypeDependenciesList

val wdlDependencies = List(
"commons-io" % "commons-io" % commonsIoV,
@@ -421,15 +420,14 @@
"org.javadelight" % "delight-rhino-sandbox" % delightRhinoSandboxV,
"org.scalamock" %% "scalamock" % scalamockV % Test,
"commons-io" % "commons-io" % commonsIoV % Test
) ++ circeDependencies ++ womDependencies ++ refinedTypeDependenciesList ++ betterFilesDependencies ++
owlApiDependencies
) ++ betterFilesDependencies ++ owlApiDependencies

val womtoolDependencies = catsDependencies ++ slf4jBindingDependencies

val centaurCwlRunnerDependencies = List(
"com.github.scopt" %% "scopt" % scoptV,
"io.circe" %% "circe-optics" % circeOpticsV
) ++ slf4jBindingDependencies ++ circeDependencies
) ++ slf4jBindingDependencies

val coreDependencies = List(
"com.google.auth" % "google-auth-library-oauth2-http" % googleOauth2V,
4 changes: 4 additions & 0 deletions wom/src/main/resources/reference.conf
@@ -0,0 +1,4 @@
yaml {
# The maximum number of nodes (scalars + sequences + mappings) that will be parsed in a YAML document
max-nodes = 1000000
}
105 changes: 105 additions & 0 deletions wom/src/main/scala/wom/util/YamlUtils.scala
@@ -0,0 +1,105 @@
package wom.util

import java.io.StringReader

import com.typesafe.config.ConfigException.BadValue
import com.typesafe.config.{Config, ConfigFactory}
import eu.timepit.refined.api.Refined
import eu.timepit.refined.numeric.NonNegative
import eu.timepit.refined.refineV
import io.circe.{Json, ParsingFailure}
import net.ceedubs.ficus.Ficus._
import net.ceedubs.ficus.readers.ValueReader

import scala.collection.JavaConverters._

object YamlUtils {

private[util] implicit val refinedNonNegativeReader: ValueReader[Int Refined NonNegative] = {
(config: Config, path: String) => {
val int = config.getInt(path)
refineV[NonNegative](int) match {
case Left(error) => throw new BadValue(path, error)
case Right(refinedInt) => refinedInt
}
}
}

private val defaultMaxNodes = ConfigFactory.load().as[Int Refined NonNegative]("yaml.max-nodes")

/**
* Parses the yaml, detecting loops and enforcing a maximum number of nodes.
*
* See: https://en.wikipedia.org/wiki/Billion_laughs_attack
*
* @param yaml The yaml string.
* @param maxNodes Maximum number of yaml nodes to parse.
* @return The parsed yaml.
*/
def parse(yaml: String, maxNodes: Int Refined NonNegative = defaultMaxNodes): Either[ParsingFailure, Json] = {
try {
val snakeYamlParser = new org.yaml.snakeyaml.Yaml()
val parsed = snakeYamlParser.load[AnyRef](new StringReader(yaml))
// Use identity comparisons to check if two nodes are the same and do NOT recurse into them checking equality
val identityHashMap = new java.util.IdentityHashMap[AnyRef, java.lang.Boolean]()
// Since we don't actually need the values, wrap the Map in a Set
val identitySet = java.util.Collections.newSetFromMap(identityHashMap)
searchForOversizedYaml(parsed, identitySet, maxNodes, new Counter)
io.circe.yaml.parser.parse(yaml)
} catch {
case exception: Exception =>
Left(ParsingFailure(exception.getMessage, exception))
}
}

/** A "pointer" reference to a mutable count. */
private class Counter {
var count = 0L
}

/**
* Looks for loops and large documents in yaml parsed by SnakeYaml.
*
* Possibly can be refactored if/when Circe switches to SnakeYaml-Engine (aka Yaml 1.2), as SY-E disables recursive
* keys by default:
*
* - https://github.com/circe/circe-yaml/issues/46
* - https://bitbucket.org/asomov/snakeyaml/issues/432/aggressive-yaml-anchors-causing
* - https://bitbucket.org/asomov/snakeyaml-engine/commits/7573a8b4d551fc84521f4ac1234a361bfbc96698
*
* @param node Current node to be evaluated.
* @param identitySet Previously seen nodes during this branch of cycle checking.
* @param maxNodes The maximum number of nodes allowed to be traversed.
* @param counter A counter tracking the total number of nodes visited during the entire traversal.
*/
private def searchForOversizedYaml(node: AnyRef,
identitySet: java.util.Set[AnyRef],
maxNodes: Int Refined NonNegative,
counter: Counter): Unit = {
if (!identitySet.add(node)) {
throw new IllegalArgumentException("Loop detected")
}

counter.count += 1
if (counter.count > maxNodes.value) {
throw new IllegalArgumentException(s"Loop detection halted at $maxNodes nodes")
}

node match {
case iterable: java.lang.Iterable[AnyRef]@unchecked =>
iterable.asScala foreach {
searchForOversizedYaml(_, identitySet, maxNodes, counter)
}
case map: java.util.Map[AnyRef, AnyRef]@unchecked =>
map.asScala foreach {
case (key, value) =>
searchForOversizedYaml(key, identitySet, maxNodes, counter)
searchForOversizedYaml(value, identitySet, maxNodes, counter)
}
case _ => /* ignore scalars, only loop through Yaml sequences and mappings: https://yaml.org/spec/1.1/#id861435 */
}

identitySet.remove(node)
()
}
}
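
A brief usage sketch of the refined config reader above. The demo object and config strings are hypothetical, and the snippet assumes it lives in the `wom.util` package since `refinedNonNegativeReader` is package-private.

```scala
package wom.util

import com.typesafe.config.ConfigFactory
import eu.timepit.refined.api.Refined
import eu.timepit.refined.numeric.NonNegative
import net.ceedubs.ficus.Ficus._

object YamlConfigDemo extends App {
  // Bring the package-private ValueReader into implicit scope.
  import YamlUtils.refinedNonNegativeReader

  // A valid override is read as a refined non-negative Int.
  val ok = ConfigFactory.parseString("yaml.max-nodes = 2000000")
    .as[Int Refined NonNegative]("yaml.max-nodes")
  println(ok.value) // 2000000

  // A negative value fails fast with ConfigException.BadValue rather than
  // silently weakening the protection.
  val bad = scala.util.Try {
    ConfigFactory.parseString("yaml.max-nodes = -1")
      .as[Int Refined NonNegative]("yaml.max-nodes")
  }
  println(bad.isFailure) // true
}
```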