diff --git a/crates/ra_hir/src/from_source.rs b/crates/ra_hir/src/from_source.rs index f80d8eb5f1..a137aeb909 100644 --- a/crates/ra_hir/src/from_source.rs +++ b/crates/ra_hir/src/from_source.rs @@ -181,7 +181,6 @@ impl Module { ) -> Option { let decl_id = match src.ast { ModuleSource::Module(ref module) => { - assert!(!module.has_semi()); let ast_id_map = db.ast_id_map(src.file_id); let item_id = ast_id_map.ast_id(module).with_file_id(src.file_id); Some(item_id) diff --git a/crates/ra_ide_api/src/references.rs b/crates/ra_ide_api/src/references.rs index 0f1ac57fc1..3d647d2cbc 100644 --- a/crates/ra_ide_api/src/references.rs +++ b/crates/ra_ide_api/src/references.rs @@ -232,7 +232,7 @@ mod tests { } #[test] - fn test_find_all_refs_modules() { + fn test_find_all_refs_two_modules() { let code = r#" //- /lib.rs pub mod foo; @@ -270,6 +270,33 @@ mod tests { assert_eq!(refs.len(), 3); } + + // `mod foo;` is not in the results because `foo` is an `ast::Name`. + // So, there are two references: the first one is a definition of the `foo` module, + // which is the whole `foo.rs`, and the second one is in `use foo::Foo`. + #[test] + fn test_find_all_refs_decl_module() { + let code = r#" + //- /lib.rs + mod foo<|>; + + use foo::Foo; + + fn f() { + let i = Foo { n: 5 }; + } + + //- /foo.rs + pub struct Foo { + pub n: u32, + } + "#; + + let (analysis, pos) = analysis_and_position(code); + let refs = analysis.find_all_refs(pos).unwrap().unwrap(); + assert_eq!(refs.len(), 2); + } + fn get_all_refs(text: &str) -> ReferenceSearchResult { let (analysis, position) = single_file_with_position(text); analysis.find_all_refs(position).unwrap().unwrap() diff --git a/crates/ra_ide_api/src/references/classify.rs b/crates/ra_ide_api/src/references/classify.rs index 93e079cccb..ac9cf34eb0 100644 --- a/crates/ra_ide_api/src/references/classify.rs +++ b/crates/ra_ide_api/src/references/classify.rs @@ -31,9 +31,16 @@ pub(crate) fn classify_name( Some(from_struct_field(db, field)) }, ast::Module(it) => { - let ast = hir::ModuleSource::Module(it); - let src = hir::Source { file_id, ast }; - let def = hir::Module::from_definition(db, src)?; + let def = { + if !it.has_semi() { + let ast = hir::ModuleSource::Module(it); + let src = hir::Source { file_id, ast }; + hir::Module::from_definition(db, src) + } else { + let src = hir::Source { file_id, ast: it }; + hir::Module::from_declaration(db, src) + } + }?; Some(from_module_def(db, def.into(), None)) }, ast::StructDef(it) => { diff --git a/crates/ra_ide_api/src/references/search_scope.rs b/crates/ra_ide_api/src/references/search_scope.rs index 680988a214..d2c966b4f3 100644 --- a/crates/ra_ide_api/src/references/search_scope.rs +++ b/crates/ra_ide_api/src/references/search_scope.rs @@ -1,5 +1,7 @@ //! FIXME: write short doc here +use std::collections::HashSet; + use hir::{DefWithBody, HasSource, ModuleSource}; use ra_db::{FileId, SourceDatabase, SourceDatabaseExt}; use ra_syntax::{AstNode, TextRange}; @@ -9,7 +11,7 @@ use crate::db::RootDatabase; use super::{NameDefinition, NameKind}; impl NameDefinition { - pub(crate) fn search_scope(&self, db: &RootDatabase) -> Vec<(FileId, Option)> { + pub(crate) fn search_scope(&self, db: &RootDatabase) -> HashSet<(FileId, Option)> { let module_src = self.container.definition_source(db); let file_id = module_src.file_id.original_file(db); @@ -19,13 +21,13 @@ impl NameDefinition { DefWithBody::Const(c) => c.source(db).ast.syntax().text_range(), DefWithBody::Static(s) => s.source(db).ast.syntax().text_range(), }; - return vec![(file_id, Some(range))]; + return [(file_id, Some(range))].iter().cloned().collect(); } if let Some(ref vis) = self.visibility { let source_root_id = db.file_source_root(file_id); let source_root = db.source_root(source_root_id); - let mut files = source_root.walk().map(|id| (id.into(), None)).collect::>(); + let mut files = source_root.walk().map(|id| (id.into(), None)).collect::>(); if vis.syntax().to_string().as_str() == "pub(crate)" { return files; @@ -54,6 +56,6 @@ impl NameDefinition { ModuleSource::Module(m) => Some(m.syntax().text_range()), ModuleSource::SourceFile(_) => None, }; - vec![(file_id, range)] + [(file_id, range)].iter().cloned().collect() } }