From df1a227cc7ec3cc5fa51a937714c303c2eb193b3 Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Sun, 13 Jan 2019 09:59:23 -0800 Subject: [PATCH] Unbounded depth --- .travis.yml | 1 + Cargo.toml | 13 +++- src/de.rs | 162 +++++++++++++++++++++++++++++++++----------------- tests/test.rs | 13 ++++ 4 files changed, 132 insertions(+), 57 deletions(-) diff --git a/.travis.yml b/.travis.yml index ef58beb..1e75da9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,6 +10,7 @@ matrix: - cargo test --features preserve_order - cargo test --features arbitrary_precision - cargo test --features raw_value + - cargo test --features unbounded_depth - rust: 1.15.0 script: diff --git a/Cargo.toml b/Cargo.toml index 95a170c..e4769c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,9 +25,10 @@ ryu = "0.2" compiletest_rs = { version = "0.3", features = ["stable"] } serde_bytes = "0.10" serde_derive = "1.0" +serde_stacker = "0.1" [package.metadata.docs.rs] -features = ["raw_value"] +features = ["raw_value", "unbounded_depth"] [package.metadata.playground] features = ["raw_value"] @@ -50,3 +51,13 @@ arbitrary_precision = [] # Provide a RawValue type that can hold unprocessed JSON during deserialization. raw_value = [] + +# Provide a method disable_recursion_limit to parse arbitrarily deep JSON +# structures without any consideration for overflowing the stack. When using +# this feature, you will want to provide some other way to protect against stack +# overflows, such as by wrapping your Deserializer in the dynamically growing +# stack adapter provided by the serde_stacker crate. Additionally you will need +# to be careful around other recursive operations on the parsed result which may +# overflow the stack after deserialization has completed, including, but not +# limited to, Display and Debug and Drop impls. +unbounded_depth = [] diff --git a/src/de.rs b/src/de.rs index 0a91a0b..617fac7 100644 --- a/src/de.rs +++ b/src/de.rs @@ -25,6 +25,8 @@ pub struct Deserializer { read: R, scratch: Vec, remaining_depth: u8, + #[cfg(feature = "unbounded_depth")] + disable_recursion_limit: bool, } impl<'de, R> Deserializer @@ -44,6 +46,8 @@ where read: read, scratch: Vec::new(), remaining_depth: 128, + #[cfg(feature = "unbounded_depth")] + disable_recursion_limit: false, } } } @@ -140,6 +144,54 @@ impl<'de, R: Read<'de>> Deserializer { } } + /// Parse arbitrarily deep JSON structures without any consideration for + /// overflowing the stack. + /// + /// You will want to provide some other way to protect against stack + /// overflows, such as by wrapping your Deserializer in the dynamically + /// growing stack adapter provided by the serde_stacker crate. Additionally + /// you will need to be careful around other recursive operations on the + /// parsed result which may overflow the stack after deserialization has + /// completed, including, but not limited to, Display and Debug and Drop + /// impls. + /// + /// *This method is only available if serde_json is built with the + /// `"unbounded_depth"` feature.* + /// + /// # Examples + /// + /// ```edition2018 + /// use serde::Deserialize; + /// use serde_json::Value; + /// + /// fn main() { + /// let mut json = String::new(); + /// for _ in 0..10000 { + /// json = format!("[{}]", json); + /// } + /// + /// let mut deserializer = serde_json::Deserializer::from_str(&json); + /// deserializer.disable_recursion_limit(); + /// let deserializer = serde_stacker::Deserializer::new(&mut deserializer); + /// let value = Value::deserialize(deserializer).unwrap(); + /// + /// carefully_drop_nested_arrays(value); + /// } + /// + /// fn carefully_drop_nested_arrays(value: Value) { + /// let mut stack = vec![value]; + /// while let Some(value) = stack.pop() { + /// if let Value::Array(array) = value { + /// stack.extend(array); + /// } + /// } + /// } + /// ``` + #[cfg(feature = "unbounded_depth")] + pub fn disable_recursion_limit(&mut self) { + self.disable_recursion_limit = true; + } + fn peek(&mut self) -> Result> { self.read.peek() } @@ -983,6 +1035,39 @@ macro_rules! deserialize_prim_number { } } +#[cfg(not(feature = "unbounded_depth"))] +macro_rules! if_checking_recursion_limit { + ($($body:tt)*) => { + $($body)* + }; +} + +#[cfg(feature = "unbounded_depth")] +macro_rules! if_checking_recursion_limit { + ($self:ident $($body:tt)*) => { + if !$self.disable_recursion_limit { + $self $($body)* + } + }; +} + +macro_rules! check_recursion { + ($self:ident $($body:tt)*) => { + if_checking_recursion_limit! { + $self.remaining_depth -= 1; + if $self.remaining_depth == 0 { + return Err($self.peek_error(ErrorCode::RecursionLimitExceeded)); + } + } + + $self $($body)* + + if_checking_recursion_limit! { + $self.remaining_depth += 1; + } + }; +} + impl<'de, 'a, R: Read<'de>> de::Deserializer<'de> for &'a mut Deserializer { type Error = Error; @@ -1028,32 +1113,22 @@ impl<'de, 'a, R: Read<'de>> de::Deserializer<'de> for &'a mut Deserializer { } } b'[' => { - self.remaining_depth -= 1; - if self.remaining_depth == 0 { - return Err(self.peek_error(ErrorCode::RecursionLimitExceeded)); + check_recursion! { + self.eat_char(); + let ret = visitor.visit_seq(SeqAccess::new(self)); } - self.eat_char(); - let ret = visitor.visit_seq(SeqAccess::new(self)); - - self.remaining_depth += 1; - match (ret, self.end_seq()) { (Ok(ret), Ok(())) => Ok(ret), (Err(err), _) | (_, Err(err)) => Err(err), } } b'{' => { - self.remaining_depth -= 1; - if self.remaining_depth == 0 { - return Err(self.peek_error(ErrorCode::RecursionLimitExceeded)); + check_recursion! { + self.eat_char(); + let ret = visitor.visit_map(MapAccess::new(self)); } - self.eat_char(); - let ret = visitor.visit_map(MapAccess::new(self)); - - self.remaining_depth += 1; - match (ret, self.end_map()) { (Ok(ret), Ok(())) => Ok(ret), (Err(err), _) | (_, Err(err)) => Err(err), @@ -1414,16 +1489,11 @@ impl<'de, 'a, R: Read<'de>> de::Deserializer<'de> for &'a mut Deserializer { let value = match peek { b'[' => { - self.remaining_depth -= 1; - if self.remaining_depth == 0 { - return Err(self.peek_error(ErrorCode::RecursionLimitExceeded)); + check_recursion! { + self.eat_char(); + let ret = visitor.visit_seq(SeqAccess::new(self)); } - self.eat_char(); - let ret = visitor.visit_seq(SeqAccess::new(self)); - - self.remaining_depth += 1; - match (ret, self.end_seq()) { (Ok(ret), Ok(())) => Ok(ret), (Err(err), _) | (_, Err(err)) => Err(err), @@ -1470,16 +1540,11 @@ impl<'de, 'a, R: Read<'de>> de::Deserializer<'de> for &'a mut Deserializer { let value = match peek { b'{' => { - self.remaining_depth -= 1; - if self.remaining_depth == 0 { - return Err(self.peek_error(ErrorCode::RecursionLimitExceeded)); + check_recursion! { + self.eat_char(); + let ret = visitor.visit_map(MapAccess::new(self)); } - self.eat_char(); - let ret = visitor.visit_map(MapAccess::new(self)); - - self.remaining_depth += 1; - match (ret, self.end_map()) { (Ok(ret), Ok(())) => Ok(ret), (Err(err), _) | (_, Err(err)) => Err(err), @@ -1512,32 +1577,22 @@ impl<'de, 'a, R: Read<'de>> de::Deserializer<'de> for &'a mut Deserializer { let value = match peek { b'[' => { - self.remaining_depth -= 1; - if self.remaining_depth == 0 { - return Err(self.peek_error(ErrorCode::RecursionLimitExceeded)); + check_recursion! { + self.eat_char(); + let ret = visitor.visit_seq(SeqAccess::new(self)); } - self.eat_char(); - let ret = visitor.visit_seq(SeqAccess::new(self)); - - self.remaining_depth += 1; - match (ret, self.end_seq()) { (Ok(ret), Ok(())) => Ok(ret), (Err(err), _) | (_, Err(err)) => Err(err), } } b'{' => { - self.remaining_depth -= 1; - if self.remaining_depth == 0 { - return Err(self.peek_error(ErrorCode::RecursionLimitExceeded)); + check_recursion! { + self.eat_char(); + let ret = visitor.visit_map(MapAccess::new(self)); } - self.eat_char(); - let ret = visitor.visit_map(MapAccess::new(self)); - - self.remaining_depth += 1; - match (ret, self.end_map()) { (Ok(ret), Ok(())) => Ok(ret), (Err(err), _) | (_, Err(err)) => Err(err), @@ -1566,16 +1621,11 @@ impl<'de, 'a, R: Read<'de>> de::Deserializer<'de> for &'a mut Deserializer { { match try!(self.parse_whitespace()) { Some(b'{') => { - self.remaining_depth -= 1; - if self.remaining_depth == 0 { - return Err(self.peek_error(ErrorCode::RecursionLimitExceeded)); + check_recursion! { + self.eat_char(); + let value = try!(visitor.visit_enum(VariantAccess::new(self))); } - self.eat_char(); - let value = try!(visitor.visit_enum(VariantAccess::new(self))); - - self.remaining_depth += 1; - match try!(self.parse_whitespace()) { Some(b'}') => { self.eat_char(); diff --git a/tests/test.rs b/tests/test.rs index 46b3396..ac8d9e1 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1764,6 +1764,19 @@ fn test_stack_overflow() { test_parse_err::(&[(&brackets, "recursion limit exceeded at line 1 column 128")]); } +#[test] +#[cfg(feature = "unbounded_depth")] +fn test_disable_recursion_limit() { + let brackets: String = iter::repeat('[') + .take(140) + .chain(iter::repeat(']').take(140)) + .collect(); + + let mut deserializer = Deserializer::from_str(&brackets); + deserializer.disable_recursion_limit(); + Value::deserialize(&mut deserializer).unwrap(); +} + #[test] fn test_integer_key() { // map with integer keys