Add 'did you mean' error (#305)

This commit is contained in:
JT 2021-11-08 10:48:50 +13:00 committed by GitHub
parent dd6452dfaa
commit 19645575d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 12 deletions

View file

@ -182,6 +182,10 @@ pub enum ShellError {
#[error("Plugin error")]
PluginError(String),
#[error("Name not found")]
#[diagnostic(code(nu::shell::name_not_found), url(docsrs))]
DidYouMean(String, #[label("did you mean '{0}'?")] Span),
}
impl From<std::io::Error> for ShellError {
@ -201,3 +205,82 @@ impl From<Box<dyn std::error::Error + Send + Sync>> for ShellError {
ShellError::InternalError(format!("{:?}", input))
}
}
pub fn did_you_mean(possibilities: &[String], tried: &str) -> Option<String> {
let mut possible_matches: Vec<_> = possibilities
.iter()
.map(|word| {
let edit_distance = levenshtein_distance(word, tried);
(edit_distance, word.to_owned())
})
.collect();
possible_matches.sort();
if let Some((_, first)) = possible_matches.into_iter().next() {
Some(first)
} else {
None
}
}
// Borrowed from here https://github.com/wooorm/levenshtein-rs
pub fn levenshtein_distance(a: &str, b: &str) -> usize {
let mut result = 0;
/* Shortcut optimizations / degenerate cases. */
if a == b {
return result;
}
let length_a = a.chars().count();
let length_b = b.chars().count();
if length_a == 0 {
return length_b;
}
if length_b == 0 {
return length_a;
}
/* Initialize the vector.
*
* This is why its fast, normally a matrix is used,
* here we use a single vector. */
let mut cache: Vec<usize> = (1..).take(length_a).collect();
let mut distance_a;
let mut distance_b;
/* Loop. */
for (index_b, code_b) in b.chars().enumerate() {
result = index_b;
distance_a = index_b;
for (index_a, code_a) in a.chars().enumerate() {
distance_b = if code_a == code_b {
distance_a
} else {
distance_a + 1
};
distance_a = cache[index_a];
result = if distance_a > result {
if distance_b > result {
result + 1
} else {
distance_b
}
} else if distance_b > distance_a {
distance_a + 1
} else {
distance_b
};
cache[index_a] = result;
}
}
result
}

View file

@ -15,7 +15,7 @@ use std::collections::HashMap;
use std::{cmp::Ordering, fmt::Debug};
use crate::ast::{CellPath, PathMember};
use crate::{span, BlockId, Span, Spanned, Type};
use crate::{did_you_mean, span, BlockId, Span, Spanned, Type};
use crate::ShellError;
@ -269,17 +269,16 @@ impl Value {
span: origin_span,
} => match &mut current {
Value::Record { cols, vals, span } => {
let cols = cols.clone();
let span = *span;
let mut found = false;
for col in cols.iter().zip(vals.iter()) {
if col.0 == column_name {
current = col.1.clone();
found = true;
break;
}
}
if !found {
if let Some(found) =
cols.iter().zip(vals.iter()).find(|x| x.0 == column_name)
{
current = found.1.clone();
} else if let Some(suggestion) = did_you_mean(&cols, column_name) {
return Err(ShellError::DidYouMean(suggestion, *origin_span));
} else {
return Err(ShellError::CantFindColumn(*origin_span, span));
}
}

View file

@ -730,7 +730,7 @@ fn earlier_errors() -> TestResult {
fn missing_column_error() -> TestResult {
fail_test(
r#"([([[name, size]; [ABC, 10], [DEF, 20]]).1, ([[name]; [HIJ]]).0]).size | table"#,
"cannot find column",
"did you mean 'name'?",
)
}
@ -829,7 +829,7 @@ fn shorthand_env_3() -> TestResult {
#[test]
fn shorthand_env_4() -> TestResult {
fail_test(r#"FOO=BAZ FOO= $nu.env.FOO"#, "cannot find column")
fail_test(r#"FOO=BAZ FOO= $nu.env.FOO"#, "did you mean")
}
#[test]