diff --git a/crates/ra_ide/src/references.rs b/crates/ra_ide/src/references.rs index 5e2fe1905a..ebded715d6 100644 --- a/crates/ra_ide/src/references.rs +++ b/crates/ra_ide/src/references.rs @@ -112,25 +112,20 @@ impl IntoIterator for ReferenceSearchResult { pub(crate) fn find_all_refs( db: &RootDatabase, - mut position: FilePosition, + position: FilePosition, search_scope: Option, ) -> Option> { let parse = db.parse(position.file_id); let syntax = parse.tree().syntax().clone(); - let token = syntax.token_at_offset(position.offset); - let mut search_kind = ReferenceKind::Other; + let (opt_name, search_kind) = + if let Some(name) = get_struct_def_name_for_struc_litetal_search(&syntax, position) { + (Some(name), ReferenceKind::StructLiteral) + } else { + (find_node_at_offset::(&syntax, position.offset), ReferenceKind::Other) + }; - if let TokenAtOffset::Between(ref left, ref right) = token { - if (right.kind() == SyntaxKind::L_CURLY || right.kind() == SyntaxKind::L_PAREN) - && left.kind() != SyntaxKind::IDENT - { - position = FilePosition { offset: left.text_range().start(), ..position }; - search_kind = ReferenceKind::StructLiteral; - } - } - - let RangeInfo { range, info: (name, def) } = find_name(db, &syntax, position)?; + let RangeInfo { range, info: (name, def) } = find_name(db, &syntax, position, opt_name)?; let declaration = match def.kind { NameKind::Macro(mac) => mac.to_nav(db), @@ -170,9 +165,10 @@ fn find_name( db: &RootDatabase, syntax: &SyntaxNode, position: FilePosition, + opt_name: Option, ) -> Option> { let mut sb = SourceBinder::new(db); - if let Some(name) = find_node_at_offset::(&syntax, position.offset) { + if let Some(name) = opt_name { let def = classify_name(&mut sb, InFile::new(position.file_id.into(), &name))?; let range = name.syntax().text_range(); return Some(RangeInfo::new(range, (name.text().to_string(), def))); @@ -218,15 +214,8 @@ fn process_definition( if let Some(d) = classify_name_ref(&mut sb, InFile::new(file_id.into(), &name_ref)) { if d == def { - let kind = if name_ref - .syntax() - .ancestors() - .find_map(ast::RecordLit::cast) - .and_then(|l| l.path()) - .and_then(|p| p.segment()) - .and_then(|p| p.name_ref()) - .map(|n| n == name_ref) - .unwrap_or(false) + let kind = if is_record_lit_name_ref(&name_ref) + || is_call_expr_name_ref(&name_ref) { ReferenceKind::StructLiteral } else { @@ -301,6 +290,49 @@ fn reference_access(kind: &NameKind, name_ref: &ast::NameRef) -> Option bool { + name_ref + .syntax() + .ancestors() + .find_map(ast::RecordLit::cast) + .and_then(|l| l.path()) + .and_then(|p| p.segment()) + .map(|p| p.name_ref().as_ref() == Some(name_ref)) + .unwrap_or(false) +} + +fn get_struct_def_name_for_struc_litetal_search( + syntax: &SyntaxNode, + position: FilePosition, +) -> Option { + if let TokenAtOffset::Between(ref left, ref right) = syntax.token_at_offset(position.offset) { + if right.kind() != SyntaxKind::L_CURLY && right.kind() != SyntaxKind::L_PAREN { + return None; + } + if let Some(name) = find_node_at_offset::(&syntax, left.text_range().start()) { + return name.syntax().ancestors().find_map(ast::StructDef::cast).and_then(|l| l.name()); + } + if find_node_at_offset::(&syntax, left.text_range().start()).is_some() { + return left.ancestors().find_map(ast::StructDef::cast).and_then(|l| l.name()); + } + } + None +} + +fn is_call_expr_name_ref(name_ref: &ast::NameRef) -> bool { + name_ref + .syntax() + .ancestors() + .find_map(ast::CallExpr::cast) + .and_then(|c| match c.expr()? { + ast::Expr::PathExpr(p) => { + Some(p.path()?.segment()?.name_ref().as_ref() == Some(name_ref)) + } + _ => None, + }) + .unwrap_or(false) +} + #[cfg(test)] mod tests { use crate::{ @@ -309,7 +341,7 @@ mod tests { }; #[test] - fn test_struct_literal() { + fn test_struct_literal_after_space() { let code = r#" struct Foo <|>{ a: i32, @@ -330,6 +362,58 @@ mod tests { ); } + #[test] + fn test_struct_literal_befor_space() { + let code = r#" + struct Foo<|> {} + fn main() { + let f: Foo; + f = Foo {}; + }"#; + + let refs = get_all_refs(code); + check_result( + refs, + "Foo STRUCT_DEF FileId(1) [5; 18) [12; 15) Other", + &["FileId(1) [54; 57) Other", "FileId(1) [71; 74) StructLiteral"], + ); + } + + #[test] + fn test_struct_literal_with_generic_type() { + let code = r#" + struct Foo <|>{} + fn main() { + let f: Foo::; + f = Foo {}; + }"#; + + let refs = get_all_refs(code); + check_result( + refs, + "Foo STRUCT_DEF FileId(1) [5; 21) [12; 15) Other", + &["FileId(1) [81; 84) StructLiteral"], + ); + } + + #[test] + fn test_struct_literal_for_tuple() { + let code = r#" + struct Foo<|>(i32); + + fn main() { + let f: Foo; + f = Foo(1); + }"#; + + let refs = get_all_refs(code); + check_result( + refs, + "Foo STRUCT_DEF FileId(1) [5; 21) [12; 15) Other", + &["FileId(1) [71; 74) StructLiteral"], + ); + } + #[test] fn test_find_all_refs_for_local() { let code = r#" @@ -564,7 +648,7 @@ mod tests { check_result( refs, "quux FN_DEF FileId(1) [18; 34) [25; 29) Other", - &["FileId(2) [16; 20) Other", "FileId(3) [16; 20) Other"], + &["FileId(2) [16; 20) StructLiteral", "FileId(3) [16; 20) StructLiteral"], ); let refs = @@ -572,7 +656,7 @@ mod tests { check_result( refs, "quux FN_DEF FileId(1) [18; 34) [25; 29) Other", - &["FileId(3) [16; 20) Other"], + &["FileId(3) [16; 20) StructLiteral"], ); } @@ -591,7 +675,7 @@ mod tests { check_result( refs, "m1 MACRO_CALL FileId(1) [9; 63) [46; 48) Other", - &["FileId(1) [96; 98) Other", "FileId(1) [114; 116) Other"], + &["FileId(1) [96; 98) StructLiteral", "FileId(1) [114; 116) StructLiteral"], ); }