From 1837bf775cf7f3c7fe0f01aa29d85aaf21ff0392 Mon Sep 17 00:00:00 2001 From: JT <547158+jntrnr@users.noreply.github.com> Date: Mon, 7 Mar 2022 15:08:56 -0500 Subject: [PATCH] Default values (#4770) --- crates/nu-engine/src/eval.rs | 11 ++ crates/nu-parser/src/errors.rs | 8 ++ crates/nu-parser/src/parse_keywords.rs | 1 + crates/nu-parser/src/parser.rs | 133 +++++++++++++++++- .../src/serializers/capnp/signature.rs | 2 + crates/nu-protocol/src/ast/call.rs | 4 +- crates/nu-protocol/src/ast/cell_path.rs | 2 +- crates/nu-protocol/src/ast/expr.rs | 3 +- crates/nu-protocol/src/ast/expression.rs | 4 +- crates/nu-protocol/src/ast/import_pattern.rs | 8 +- crates/nu-protocol/src/ast/operator.rs | 2 +- crates/nu-protocol/src/signature.rs | 16 ++- crates/nu-protocol/src/span.rs | 2 +- crates/nu-protocol/src/ty.rs | 30 ++++ crates/nu-protocol/src/value/unit.rs | 4 +- crates/nu-protocol/tests/test_signature.rs | 15 +- src/tests/test_engine.rs | 63 +++++++++ 17 files changed, 289 insertions(+), 19 deletions(-) diff --git a/crates/nu-engine/src/eval.rs b/crates/nu-engine/src/eval.rs index 7bee98fbfc..322ea0cd74 100644 --- a/crates/nu-engine/src/eval.rs +++ b/crates/nu-engine/src/eval.rs @@ -62,6 +62,9 @@ fn eval_call( if let Some(arg) = call.positional.get(param_idx) { let result = eval_expression(engine_state, caller_stack, arg)?; callee_stack.add_var(var_id, result); + } else if let Some(arg) = ¶m.default_value { + let result = eval_expression(engine_state, caller_stack, arg)?; + callee_stack.add_var(var_id, result); } else { callee_stack.add_var(var_id, Value::nothing(call.head)); } @@ -103,6 +106,10 @@ fn eval_call( if let Some(arg) = &call_named.1 { let result = eval_expression(engine_state, caller_stack, arg)?; + callee_stack.add_var(var_id, result); + } else if let Some(arg) = &named.default_value { + let result = eval_expression(engine_state, caller_stack, arg)?; + callee_stack.add_var(var_id, result); } else { callee_stack.add_var( @@ -126,6 +133,10 @@ fn eval_call( span: call.head, }, ) + } else if let Some(arg) = &named.default_value { + let result = eval_expression(engine_state, caller_stack, arg)?; + + callee_stack.add_var(var_id, result); } else { callee_stack.add_var(var_id, Value::Nothing { span: call.head }) } diff --git a/crates/nu-parser/src/errors.rs b/crates/nu-parser/src/errors.rs index dfd7734416..6e31ce7995 100644 --- a/crates/nu-parser/src/errors.rs +++ b/crates/nu-parser/src/errors.rs @@ -19,6 +19,13 @@ pub enum ParseError { #[diagnostic(code(nu::parser::extra_positional), url(docsrs), help("Usage: {0}"))] ExtraPositional(String, #[label = "extra positional argument"] Span), + #[error("Require positional parameter after optional parameter")] + #[diagnostic(code(nu::parser::required_after_optional), url(docsrs))] + RequiredAfterOptional( + String, + #[label = "required parameter {0} after optional parameter"] Span, + ), + #[error("Unexpected end of code.")] #[diagnostic(code(nu::parser::unexpected_eof), url(docsrs))] UnexpectedEof(String, #[label("expected closing {0}")] Span), @@ -246,6 +253,7 @@ impl ParseError { ParseError::UnknownCommand(s) => *s, ParseError::NonUtf8(s) => *s, ParseError::UnknownFlag(_, _, s) => *s, + ParseError::RequiredAfterOptional(_, s) => *s, ParseError::UnknownType(s) => *s, ParseError::MissingFlagParam(_, s) => *s, ParseError::ShortFlagBatchCantTakeArg(s) => *s, diff --git a/crates/nu-parser/src/parse_keywords.rs b/crates/nu-parser/src/parse_keywords.rs index a54e0daed6..18ca896fb3 100644 --- a/crates/nu-parser/src/parse_keywords.rs +++ b/crates/nu-parser/src/parse_keywords.rs @@ -176,6 +176,7 @@ pub fn parse_for( desc: String::new(), shape: SyntaxShape::Any, var_id: Some(*var_id), + default_value: None, }, ); } diff --git a/crates/nu-parser/src/parser.rs b/crates/nu-parser/src/parser.rs index 8bfdb9efa1..607637cbb8 100644 --- a/crates/nu-parser/src/parser.rs +++ b/crates/nu-parser/src/parser.rs @@ -2677,6 +2677,7 @@ pub fn parse_row_condition( desc: "row condition".into(), shape: SyntaxShape::Any, var_id: Some(var_id), + default_value: None, }); working_set.add_block(block) @@ -2742,11 +2743,14 @@ pub fn parse_signature_helper( working_set: &mut StateWorkingSet, span: Span, ) -> (Box, Option) { + #[allow(clippy::enum_variant_names)] enum ParseMode { ArgMode, TypeMode, + DefaultValueMode, } + #[derive(Debug)] enum Arg { Positional(PositionalArg, bool), // bool - required RestPositional(PositionalArg), @@ -2756,7 +2760,13 @@ pub fn parse_signature_helper( let mut error = None; let source = working_set.get_span_contents(span); - let (output, err) = lex(source, span.start, &[b'\n', b'\r', b','], &[b':'], false); + let (output, err) = lex( + source, + span.start, + &[b'\n', b'\r', b','], + &[b':', b'='], + false, + ); error = error.or(err); let mut args: Vec = vec![]; @@ -2776,12 +2786,24 @@ pub fn parse_signature_helper( ParseMode::ArgMode => { parse_mode = ParseMode::TypeMode; } - ParseMode::TypeMode => { + ParseMode::TypeMode | ParseMode::DefaultValueMode => { // We're seeing two types for the same thing for some reason, error error = error.or_else(|| Some(ParseError::Expected("type".into(), span))); } } + } else if contents == b"=" { + match parse_mode { + ParseMode::ArgMode | ParseMode::TypeMode => { + parse_mode = ParseMode::DefaultValueMode; + } + ParseMode::DefaultValueMode => { + // We're seeing two default values for some reason, error + error = error.or_else(|| { + Some(ParseError::Expected("default value".into(), span)) + }); + } + } } else { match parse_mode { ParseMode::ArgMode => { @@ -2802,6 +2824,7 @@ pub fn parse_signature_helper( short: None, required: false, var_id: Some(var_id), + default_value: None, })); } else { let short_flag = &flags[1]; @@ -2832,6 +2855,7 @@ pub fn parse_signature_helper( short: Some(chars[0]), required: false, var_id: Some(var_id), + default_value: None, })); } else { error = error.or_else(|| { @@ -2864,6 +2888,7 @@ pub fn parse_signature_helper( short: Some(chars[0]), required: false, var_id: Some(var_id), + default_value: None, })); } else if contents.starts_with(b"(-") { let short_flag = &contents[2..]; @@ -2921,6 +2946,7 @@ pub fn parse_signature_helper( name, shape: SyntaxShape::Any, var_id: Some(var_id), + default_value: None, }, false, )) @@ -2935,6 +2961,7 @@ pub fn parse_signature_helper( name, shape: SyntaxShape::Any, var_id: Some(var_id), + default_value: None, })); } else { let name = String::from_utf8_lossy(contents).to_string(); @@ -2949,6 +2976,7 @@ pub fn parse_signature_helper( name, shape: SyntaxShape::Any, var_id: Some(var_id), + default_value: None, }, true, )) @@ -2982,6 +3010,97 @@ pub fn parse_signature_helper( } parse_mode = ParseMode::ArgMode; } + ParseMode::DefaultValueMode => { + if let Some(last) = args.last_mut() { + let (expression, err) = + parse_value(working_set, span, &SyntaxShape::Any); + error = error.or(err); + + //TODO check if we're replacing a custom parameter already + match last { + Arg::Positional( + PositionalArg { + shape, + var_id, + default_value, + .. + }, + required, + ) => { + let var_id = var_id.expect("internal error: all custom parameters must have var_ids"); + let var_type = working_set.get_variable(var_id); + match var_type { + Type::Unknown => { + working_set.set_variable_type( + var_id, + expression.ty.clone(), + ); + } + t => { + if t != &expression.ty { + error = error.or_else(|| { + Some(ParseError::AssignmentMismatch( + "Default value wrong type".into(), + format!("default value not {}", t), + expression.span, + )) + }) + } + } + } + *shape = expression.ty.to_shape(); + *default_value = Some(expression); + *required = false; + } + Arg::RestPositional(..) => { + error = error.or_else(|| { + Some(ParseError::AssignmentMismatch( + "Rest parameter given default value".into(), + "can't have default value".into(), + expression.span, + )) + }) + } + Arg::Flag(Flag { + arg, + var_id, + default_value, + .. + }) => { + let var_id = var_id.expect("internal error: all custom parameters must have var_ids"); + let var_type = working_set.get_variable(var_id); + + let expression_ty = expression.ty.clone(); + let expression_span = expression.span; + + *default_value = Some(expression); + + // Flags with a boolean type are just present/not-present switches + if var_type != &Type::Bool { + match var_type { + Type::Unknown => { + *arg = Some(expression_ty.to_shape()); + working_set + .set_variable_type(var_id, expression_ty); + } + t => { + if t != &expression_ty { + error = error.or_else(|| { + Some(ParseError::AssignmentMismatch( + "Default value wrong type".into(), + format!("default value not {}", t), + expression_span, + )) + }) + } + } + } + } + } + } + } + parse_mode = ParseMode::ArgMode; + } } } } @@ -3030,6 +3149,14 @@ pub fn parse_signature_helper( match arg { Arg::Positional(positional, required) => { if required { + if !sig.optional_positional.is_empty() { + error = error.or_else(|| { + Some(ParseError::RequiredAfterOptional( + positional.name.clone(), + span, + )) + }) + } sig.required_positional.push(positional) } else { sig.optional_positional.push(positional) @@ -3379,6 +3506,7 @@ pub fn parse_block_expression( name: "$it".into(), desc: String::new(), shape: SyntaxShape::Any, + default_value: None, }); output.signature = Box::new(signature); } @@ -4503,6 +4631,7 @@ fn wrap_expr_with_collect(working_set: &mut StateWorkingSet, expr: &Expression) name: "$in".into(), desc: String::new(), shape: SyntaxShape::Any, + default_value: None, }); let mut expr = expr.clone(); diff --git a/crates/nu-plugin/src/serializers/capnp/signature.rs b/crates/nu-plugin/src/serializers/capnp/signature.rs index 82be9eac0e..9d017f7d6a 100644 --- a/crates/nu-plugin/src/serializers/capnp/signature.rs +++ b/crates/nu-plugin/src/serializers/capnp/signature.rs @@ -218,6 +218,7 @@ fn deserialize_argument(reader: argument::Reader) -> Result Result { required, desc: desc.to_string(), var_id: None, + default_value: None, }) } diff --git a/crates/nu-protocol/src/ast/call.rs b/crates/nu-protocol/src/ast/call.rs index 116478860d..2d137e17c7 100644 --- a/crates/nu-protocol/src/ast/call.rs +++ b/crates/nu-protocol/src/ast/call.rs @@ -1,7 +1,9 @@ +use serde::{Deserialize, Serialize}; + use super::Expression; use crate::{DeclId, Span, Spanned}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Call { /// identifier of the declaration to call pub decl_id: DeclId, diff --git a/crates/nu-protocol/src/ast/cell_path.rs b/crates/nu-protocol/src/ast/cell_path.rs index 02f4310f56..8487abb800 100644 --- a/crates/nu-protocol/src/ast/cell_path.rs +++ b/crates/nu-protocol/src/ast/cell_path.rs @@ -41,7 +41,7 @@ impl CellPath { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct FullCellPath { pub head: Expression, pub tail: Vec, diff --git a/crates/nu-protocol/src/ast/expr.rs b/crates/nu-protocol/src/ast/expr.rs index 941cb3eb2a..17656065f3 100644 --- a/crates/nu-protocol/src/ast/expr.rs +++ b/crates/nu-protocol/src/ast/expr.rs @@ -1,9 +1,10 @@ use chrono::FixedOffset; +use serde::{Deserialize, Serialize}; use super::{Call, CellPath, Expression, FullCellPath, Operator, RangeOperator}; use crate::{ast::ImportPattern, BlockId, Signature, Span, Spanned, Unit, VarId}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Expr { Bool(bool), Int(i64), diff --git a/crates/nu-protocol/src/ast/expression.rs b/crates/nu-protocol/src/ast/expression.rs index a343310aa5..2aaa247328 100644 --- a/crates/nu-protocol/src/ast/expression.rs +++ b/crates/nu-protocol/src/ast/expression.rs @@ -1,8 +1,10 @@ +use serde::{Deserialize, Serialize}; + use super::{Expr, Operator}; use crate::ast::ImportPattern; use crate::{engine::StateWorkingSet, BlockId, Signature, Span, Type, VarId, IN_VARIABLE_ID}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Expression { pub expr: Expr, pub span: Span, diff --git a/crates/nu-protocol/src/ast/import_pattern.rs b/crates/nu-protocol/src/ast/import_pattern.rs index 6c361f666d..5b0568aa11 100644 --- a/crates/nu-protocol/src/ast/import_pattern.rs +++ b/crates/nu-protocol/src/ast/import_pattern.rs @@ -1,21 +1,23 @@ +use serde::{Deserialize, Serialize}; + use crate::{span, OverlayId, Span}; use std::collections::HashSet; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum ImportPatternMember { Glob { span: Span }, Name { name: Vec, span: Span }, List { names: Vec<(Vec, Span)> }, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ImportPatternHead { pub name: Vec, pub id: Option, pub span: Span, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ImportPattern { pub head: ImportPatternHead, pub members: Vec, diff --git a/crates/nu-protocol/src/ast/operator.rs b/crates/nu-protocol/src/ast/operator.rs index 8291c61341..3c8e916dca 100644 --- a/crates/nu-protocol/src/ast/operator.rs +++ b/crates/nu-protocol/src/ast/operator.rs @@ -56,7 +56,7 @@ pub enum RangeInclusion { RightExclusive, } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)] pub struct RangeOperator { pub inclusion: RangeInclusion, pub span: Span, diff --git a/crates/nu-protocol/src/signature.rs b/crates/nu-protocol/src/signature.rs index bfbb438ae7..511e0534bb 100644 --- a/crates/nu-protocol/src/signature.rs +++ b/crates/nu-protocol/src/signature.rs @@ -2,6 +2,7 @@ use serde::Deserialize; use serde::Serialize; use crate::ast::Call; +use crate::ast::Expression; use crate::engine::Command; use crate::engine::EngineState; use crate::engine::Stack; @@ -10,24 +11,28 @@ use crate::PipelineData; use crate::SyntaxShape; use crate::VarId; -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Flag { pub long: String, pub short: Option, pub arg: Option, pub required: bool, pub desc: String, + // For custom commands pub var_id: Option, + pub default_value: Option, } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PositionalArg { pub name: String, pub desc: String, pub shape: SyntaxShape, + // For custom commands pub var_id: Option, + pub default_value: Option, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -123,6 +128,7 @@ impl Signature { desc: "Display this help message".into(), required: false, var_id: None, + default_value: None, }; Signature { @@ -160,6 +166,7 @@ impl Signature { desc: desc.into(), shape: shape.into(), var_id: None, + default_value: None, }); self @@ -177,6 +184,7 @@ impl Signature { desc: desc.into(), shape: shape.into(), var_id: None, + default_value: None, }); self @@ -193,6 +201,7 @@ impl Signature { desc: desc.into(), shape: shape.into(), var_id: None, + default_value: None, }); self @@ -215,6 +224,7 @@ impl Signature { required: false, desc: desc.into(), var_id: None, + default_value: None, }); self @@ -237,6 +247,7 @@ impl Signature { required: true, desc: desc.into(), var_id: None, + default_value: None, }); self @@ -258,6 +269,7 @@ impl Signature { required: false, desc: desc.into(), var_id: None, + default_value: None, }); self diff --git a/crates/nu-protocol/src/span.rs b/crates/nu-protocol/src/span.rs index 282c19d1ce..391b8d7531 100644 --- a/crates/nu-protocol/src/span.rs +++ b/crates/nu-protocol/src/span.rs @@ -2,7 +2,7 @@ use miette::SourceSpan; use serde::{Deserialize, Serialize}; /// A spanned area of interest, generic over what kind of thing is of interest -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct Spanned where T: Clone + std::fmt::Debug, diff --git a/crates/nu-protocol/src/ty.rs b/crates/nu-protocol/src/ty.rs index a0b166c795..b9e8ecb557 100644 --- a/crates/nu-protocol/src/ty.rs +++ b/crates/nu-protocol/src/ty.rs @@ -2,6 +2,8 @@ use serde::{Deserialize, Serialize}; use std::fmt::Display; +use crate::SyntaxShape; + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum Type { Int, @@ -27,6 +29,34 @@ pub enum Type { Signature, } +impl Type { + pub fn to_shape(&self) -> SyntaxShape { + match self { + Type::Int => SyntaxShape::Int, + Type::Float => SyntaxShape::Number, + Type::Range => SyntaxShape::Range, + Type::Bool => SyntaxShape::Boolean, + Type::String => SyntaxShape::String, + Type::Block => SyntaxShape::Block(None), // FIXME needs more accuracy + Type::CellPath => SyntaxShape::CellPath, + Type::Duration => SyntaxShape::Duration, + Type::Date => SyntaxShape::DateTime, + Type::Filesize => SyntaxShape::Filesize, + Type::List(x) => SyntaxShape::List(Box::new(x.to_shape())), + Type::Number => SyntaxShape::Number, + Type::Nothing => SyntaxShape::Any, + Type::Record(_) => SyntaxShape::Record, + Type::Table => SyntaxShape::Table, + Type::ListStream => SyntaxShape::List(Box::new(SyntaxShape::Any)), + Type::Unknown => SyntaxShape::Any, + Type::Error => SyntaxShape::Any, + Type::Binary => SyntaxShape::Binary, + Type::Custom => SyntaxShape::Custom(Box::new(SyntaxShape::Any), String::new()), + Type::Signature => SyntaxShape::Signature, + } + } +} + impl Display for Type { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/crates/nu-protocol/src/value/unit.rs b/crates/nu-protocol/src/value/unit.rs index 27fd893014..523798ed06 100644 --- a/crates/nu-protocol/src/value/unit.rs +++ b/crates/nu-protocol/src/value/unit.rs @@ -1,4 +1,6 @@ -#[derive(Debug, Clone, Copy)] +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] pub enum Unit { // Filesize units: metric Byte, diff --git a/crates/nu-protocol/tests/test_signature.rs b/crates/nu-protocol/tests/test_signature.rs index a2d1ce0f6e..99f50b60b4 100644 --- a/crates/nu-protocol/tests/test_signature.rs +++ b/crates/nu-protocol/tests/test_signature.rs @@ -46,7 +46,8 @@ fn test_signature_chained() { name: "required".to_string(), desc: "required description".to_string(), shape: SyntaxShape::String, - var_id: None + var_id: None, + default_value: None, }) ); assert_eq!( @@ -55,7 +56,8 @@ fn test_signature_chained() { name: "optional".to_string(), desc: "optional description".to_string(), shape: SyntaxShape::String, - var_id: None + var_id: None, + default_value: None, }) ); assert_eq!( @@ -64,7 +66,8 @@ fn test_signature_chained() { name: "rest".to_string(), desc: "rest description".to_string(), shape: SyntaxShape::String, - var_id: None + var_id: None, + default_value: None, }) ); @@ -76,7 +79,8 @@ fn test_signature_chained() { arg: Some(SyntaxShape::String), required: true, desc: "required named description".to_string(), - var_id: None + var_id: None, + default_value: None, }) ); @@ -88,7 +92,8 @@ fn test_signature_chained() { arg: Some(SyntaxShape::String), required: true, desc: "required named description".to_string(), - var_id: None + var_id: None, + default_value: None, }) ); } diff --git a/src/tests/test_engine.rs b/src/tests/test_engine.rs index ac0bc828a2..db9ca46bd8 100644 --- a/src/tests/test_engine.rs +++ b/src/tests/test_engine.rs @@ -287,3 +287,66 @@ fn bool_variable() -> TestResult { fn bool_variable2() -> TestResult { run_test(r#"$false"#, "false") } + +#[test] +fn default_value1() -> TestResult { + run_test(r#"def foo [x = 3] { $x }; foo"#, "3") +} + +#[test] +fn default_value2() -> TestResult { + run_test(r#"def foo [x: int = 3] { $x }; foo"#, "3") +} + +#[test] +fn default_value3() -> TestResult { + run_test(r#"def foo [--x = 3] { $x }; foo"#, "3") +} + +#[test] +fn default_value4() -> TestResult { + run_test(r#"def foo [--x: int = 3] { $x }; foo"#, "3") +} + +#[test] +fn default_value5() -> TestResult { + run_test(r#"def foo [x = 3] { $x }; foo 10"#, "10") +} + +#[test] +fn default_value6() -> TestResult { + run_test(r#"def foo [x: int = 3] { $x }; foo 10"#, "10") +} + +#[test] +fn default_value7() -> TestResult { + run_test(r#"def foo [--x = 3] { $x }; foo --x 10"#, "10") +} + +#[test] +fn default_value8() -> TestResult { + run_test(r#"def foo [--x: int = 3] { $x }; foo --x 10"#, "10") +} + +#[test] +fn default_value9() -> TestResult { + fail_test(r#"def foo [--x = 3] { $x }; foo --x a"#, "expected int") +} + +#[test] +fn default_value10() -> TestResult { + fail_test(r#"def foo [x = 3] { $x }; foo a"#, "expected int") +} + +#[test] +fn default_value11() -> TestResult { + fail_test( + r#"def foo [x = 3, y] { $x }; foo a"#, + "after optional parameter", + ) +} + +#[test] +fn default_value12() -> TestResult { + fail_test(r#"def foo [--x:int = "a"] { $x }"#, "default value not int") +}