Add BindingId

This commit is contained in:
hkalbasi 2023-02-19 00:02:55 +03:30
parent a360fab9a3
commit 61ad6a96ad
27 changed files with 514 additions and 297 deletions

View file

@ -24,7 +24,7 @@ use syntax::{ast, AstPtr, SyntaxNode, SyntaxNodePtr};
use crate::{ use crate::{
attr::Attrs, attr::Attrs,
db::DefDatabase, db::DefDatabase,
expr::{dummy_expr_id, Expr, ExprId, Label, LabelId, Pat, PatId}, expr::{dummy_expr_id, Binding, BindingId, Expr, ExprId, Label, LabelId, Pat, PatId},
item_scope::BuiltinShadowMode, item_scope::BuiltinShadowMode,
macro_id_to_def_id, macro_id_to_def_id,
nameres::DefMap, nameres::DefMap,
@ -270,6 +270,7 @@ pub struct Mark {
pub struct Body { pub struct Body {
pub exprs: Arena<Expr>, pub exprs: Arena<Expr>,
pub pats: Arena<Pat>, pub pats: Arena<Pat>,
pub bindings: Arena<Binding>,
pub or_pats: FxHashMap<PatId, Arc<[PatId]>>, pub or_pats: FxHashMap<PatId, Arc<[PatId]>>,
pub labels: Arena<Label>, pub labels: Arena<Label>,
/// The patterns for the function's parameters. While the parameter types are /// The patterns for the function's parameters. While the parameter types are
@ -435,13 +436,24 @@ impl Body {
} }
fn shrink_to_fit(&mut self) { fn shrink_to_fit(&mut self) {
let Self { _c: _, body_expr: _, block_scopes, or_pats, exprs, labels, params, pats } = self; let Self {
_c: _,
body_expr: _,
block_scopes,
or_pats,
exprs,
labels,
params,
pats,
bindings,
} = self;
block_scopes.shrink_to_fit(); block_scopes.shrink_to_fit();
or_pats.shrink_to_fit(); or_pats.shrink_to_fit();
exprs.shrink_to_fit(); exprs.shrink_to_fit();
labels.shrink_to_fit(); labels.shrink_to_fit();
params.shrink_to_fit(); params.shrink_to_fit();
pats.shrink_to_fit(); pats.shrink_to_fit();
bindings.shrink_to_fit();
} }
} }
@ -451,6 +463,7 @@ impl Default for Body {
body_expr: dummy_expr_id(), body_expr: dummy_expr_id(),
exprs: Default::default(), exprs: Default::default(),
pats: Default::default(), pats: Default::default(),
bindings: Default::default(),
or_pats: Default::default(), or_pats: Default::default(),
labels: Default::default(), labels: Default::default(),
params: Default::default(), params: Default::default(),
@ -484,6 +497,14 @@ impl Index<LabelId> for Body {
} }
} }
impl Index<BindingId> for Body {
type Output = Binding;
fn index(&self, b: BindingId) -> &Binding {
&self.bindings[b]
}
}
// FIXME: Change `node_` prefix to something more reasonable. // FIXME: Change `node_` prefix to something more reasonable.
// Perhaps `expr_syntax` and `expr_id`? // Perhaps `expr_syntax` and `expr_id`?
impl BodySourceMap { impl BodySourceMap {

View file

@ -15,6 +15,7 @@ use la_arena::Arena;
use once_cell::unsync::OnceCell; use once_cell::unsync::OnceCell;
use profile::Count; use profile::Count;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use syntax::{ use syntax::{
ast::{ ast::{
self, ArrayExprKind, AstChildren, HasArgList, HasLoopBody, HasName, LiteralKind, self, ArrayExprKind, AstChildren, HasArgList, HasLoopBody, HasName, LiteralKind,
@ -30,9 +31,9 @@ use crate::{
builtin_type::{BuiltinFloat, BuiltinInt, BuiltinUint}, builtin_type::{BuiltinFloat, BuiltinInt, BuiltinUint},
db::DefDatabase, db::DefDatabase,
expr::{ expr::{
dummy_expr_id, Array, BindingAnnotation, ClosureKind, Expr, ExprId, FloatTypeWrapper, dummy_expr_id, Array, Binding, BindingAnnotation, BindingId, ClosureKind, Expr, ExprId,
Label, LabelId, Literal, MatchArm, Movability, Pat, PatId, RecordFieldPat, RecordLitField, FloatTypeWrapper, Label, LabelId, Literal, MatchArm, Movability, Pat, PatId,
Statement, RecordFieldPat, RecordLitField, Statement,
}, },
item_scope::BuiltinShadowMode, item_scope::BuiltinShadowMode,
path::{GenericArgs, Path}, path::{GenericArgs, Path},
@ -87,6 +88,7 @@ pub(super) fn lower(
body: Body { body: Body {
exprs: Arena::default(), exprs: Arena::default(),
pats: Arena::default(), pats: Arena::default(),
bindings: Arena::default(),
labels: Arena::default(), labels: Arena::default(),
params: Vec::new(), params: Vec::new(),
body_expr: dummy_expr_id(), body_expr: dummy_expr_id(),
@ -116,6 +118,22 @@ struct ExprCollector<'a> {
is_lowering_generator: bool, is_lowering_generator: bool,
} }
#[derive(Debug, Default)]
struct BindingList {
map: FxHashMap<Name, BindingId>,
}
impl BindingList {
fn find(
&mut self,
ec: &mut ExprCollector<'_>,
name: Name,
mode: BindingAnnotation,
) -> BindingId {
*self.map.entry(name).or_insert_with_key(|n| ec.alloc_binding(n.clone(), mode))
}
}
impl ExprCollector<'_> { impl ExprCollector<'_> {
fn collect( fn collect(
mut self, mut self,
@ -127,17 +145,16 @@ impl ExprCollector<'_> {
param_list.self_param().filter(|_| attr_enabled.next().unwrap_or(false)) param_list.self_param().filter(|_| attr_enabled.next().unwrap_or(false))
{ {
let ptr = AstPtr::new(&self_param); let ptr = AstPtr::new(&self_param);
let param_pat = self.alloc_pat( let binding_id = self.alloc_binding(
Pat::Bind { name![self],
name: name![self], BindingAnnotation::new(
mode: BindingAnnotation::new(
self_param.mut_token().is_some() && self_param.amp_token().is_none(), self_param.mut_token().is_some() && self_param.amp_token().is_none(),
false, false,
), ),
subpat: None,
},
Either::Right(ptr),
); );
let param_pat =
self.alloc_pat(Pat::Bind { id: binding_id, subpat: None }, Either::Right(ptr));
self.add_definition_to_binding(binding_id, param_pat);
self.body.params.push(param_pat); self.body.params.push(param_pat);
} }
@ -179,6 +196,9 @@ impl ExprCollector<'_> {
id id
} }
fn alloc_binding(&mut self, name: Name, mode: BindingAnnotation) -> BindingId {
self.body.bindings.alloc(Binding { name, mode, definitions: SmallVec::new() })
}
fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId { fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId {
let src = self.expander.to_source(ptr); let src = self.expander.to_source(ptr);
let id = self.make_pat(pat, src.clone()); let id = self.make_pat(pat, src.clone());
@ -804,7 +824,7 @@ impl ExprCollector<'_> {
} }
fn collect_pat(&mut self, pat: ast::Pat) -> PatId { fn collect_pat(&mut self, pat: ast::Pat) -> PatId {
let pat_id = self.collect_pat_(pat); let pat_id = self.collect_pat_(pat, &mut BindingList::default());
for (_, pats) in self.name_to_pat_grouping.drain() { for (_, pats) in self.name_to_pat_grouping.drain() {
let pats = Arc::<[_]>::from(pats); let pats = Arc::<[_]>::from(pats);
self.body.or_pats.extend(pats.iter().map(|&pat| (pat, pats.clone()))); self.body.or_pats.extend(pats.iter().map(|&pat| (pat, pats.clone())));
@ -820,7 +840,7 @@ impl ExprCollector<'_> {
} }
} }
fn collect_pat_(&mut self, pat: ast::Pat) -> PatId { fn collect_pat_(&mut self, pat: ast::Pat, binding_list: &mut BindingList) -> PatId {
let pattern = match &pat { let pattern = match &pat {
ast::Pat::IdentPat(bp) => { ast::Pat::IdentPat(bp) => {
let name = bp.name().map(|nr| nr.as_name()).unwrap_or_else(Name::missing); let name = bp.name().map(|nr| nr.as_name()).unwrap_or_else(Name::missing);
@ -828,8 +848,10 @@ impl ExprCollector<'_> {
let key = self.is_lowering_inside_or_pat.then(|| name.clone()); let key = self.is_lowering_inside_or_pat.then(|| name.clone());
let annotation = let annotation =
BindingAnnotation::new(bp.mut_token().is_some(), bp.ref_token().is_some()); BindingAnnotation::new(bp.mut_token().is_some(), bp.ref_token().is_some());
let subpat = bp.pat().map(|subpat| self.collect_pat_(subpat)); let subpat = bp.pat().map(|subpat| self.collect_pat_(subpat, binding_list));
let pattern = if annotation == BindingAnnotation::Unannotated && subpat.is_none() { let (binding, pattern) = if annotation == BindingAnnotation::Unannotated
&& subpat.is_none()
{
// This could also be a single-segment path pattern. To // This could also be a single-segment path pattern. To
// decide that, we need to try resolving the name. // decide that, we need to try resolving the name.
let (resolved, _) = self.expander.def_map.resolve_path( let (resolved, _) = self.expander.def_map.resolve_path(
@ -839,12 +861,12 @@ impl ExprCollector<'_> {
BuiltinShadowMode::Other, BuiltinShadowMode::Other,
); );
match resolved.take_values() { match resolved.take_values() {
Some(ModuleDefId::ConstId(_)) => Pat::Path(name.into()), Some(ModuleDefId::ConstId(_)) => (None, Pat::Path(name.into())),
Some(ModuleDefId::EnumVariantId(_)) => { Some(ModuleDefId::EnumVariantId(_)) => {
// this is only really valid for unit variants, but // this is only really valid for unit variants, but
// shadowing other enum variants with a pattern is // shadowing other enum variants with a pattern is
// an error anyway // an error anyway
Pat::Path(name.into()) (None, Pat::Path(name.into()))
} }
Some(ModuleDefId::AdtId(AdtId::StructId(s))) Some(ModuleDefId::AdtId(AdtId::StructId(s)))
if self.db.struct_data(s).variant_data.kind() != StructKind::Record => if self.db.struct_data(s).variant_data.kind() != StructKind::Record =>
@ -852,17 +874,24 @@ impl ExprCollector<'_> {
// Funnily enough, record structs *can* be shadowed // Funnily enough, record structs *can* be shadowed
// by pattern bindings (but unit or tuple structs // by pattern bindings (but unit or tuple structs
// can't). // can't).
Pat::Path(name.into()) (None, Pat::Path(name.into()))
} }
// shadowing statics is an error as well, so we just ignore that case here // shadowing statics is an error as well, so we just ignore that case here
_ => Pat::Bind { name, mode: annotation, subpat }, _ => {
let id = binding_list.find(self, name, annotation);
(Some(id), Pat::Bind { id, subpat })
}
} }
} else { } else {
Pat::Bind { name, mode: annotation, subpat } let id = binding_list.find(self, name, annotation);
(Some(id), Pat::Bind { id, subpat })
}; };
let ptr = AstPtr::new(&pat); let ptr = AstPtr::new(&pat);
let pat = self.alloc_pat(pattern, Either::Left(ptr)); let pat = self.alloc_pat(pattern, Either::Left(ptr));
if let Some(binding_id) = binding {
self.add_definition_to_binding(binding_id, pat);
}
if let Some(key) = key { if let Some(key) = key {
self.name_to_pat_grouping.entry(key).or_default().push(pat); self.name_to_pat_grouping.entry(key).or_default().push(pat);
} }
@ -871,11 +900,11 @@ impl ExprCollector<'_> {
ast::Pat::TupleStructPat(p) => { ast::Pat::TupleStructPat(p) => {
let path = let path =
p.path().and_then(|path| self.expander.parse_path(self.db, path)).map(Box::new); p.path().and_then(|path| self.expander.parse_path(self.db, path)).map(Box::new);
let (args, ellipsis) = self.collect_tuple_pat(p.fields()); let (args, ellipsis) = self.collect_tuple_pat(p.fields(), binding_list);
Pat::TupleStruct { path, args, ellipsis } Pat::TupleStruct { path, args, ellipsis }
} }
ast::Pat::RefPat(p) => { ast::Pat::RefPat(p) => {
let pat = self.collect_pat_opt(p.pat()); let pat = self.collect_pat_opt_(p.pat(), binding_list);
let mutability = Mutability::from_mutable(p.mut_token().is_some()); let mutability = Mutability::from_mutable(p.mut_token().is_some());
Pat::Ref { pat, mutability } Pat::Ref { pat, mutability }
} }
@ -886,12 +915,12 @@ impl ExprCollector<'_> {
} }
ast::Pat::OrPat(p) => { ast::Pat::OrPat(p) => {
self.is_lowering_inside_or_pat = true; self.is_lowering_inside_or_pat = true;
let pats = p.pats().map(|p| self.collect_pat_(p)).collect(); let pats = p.pats().map(|p| self.collect_pat_(p, binding_list)).collect();
Pat::Or(pats) Pat::Or(pats)
} }
ast::Pat::ParenPat(p) => return self.collect_pat_opt_(p.pat()), ast::Pat::ParenPat(p) => return self.collect_pat_opt_(p.pat(), binding_list),
ast::Pat::TuplePat(p) => { ast::Pat::TuplePat(p) => {
let (args, ellipsis) = self.collect_tuple_pat(p.fields()); let (args, ellipsis) = self.collect_tuple_pat(p.fields(), binding_list);
Pat::Tuple { args, ellipsis } Pat::Tuple { args, ellipsis }
} }
ast::Pat::WildcardPat(_) => Pat::Wild, ast::Pat::WildcardPat(_) => Pat::Wild,
@ -904,7 +933,7 @@ impl ExprCollector<'_> {
.fields() .fields()
.filter_map(|f| { .filter_map(|f| {
let ast_pat = f.pat()?; let ast_pat = f.pat()?;
let pat = self.collect_pat_(ast_pat); let pat = self.collect_pat_(ast_pat, binding_list);
let name = f.field_name()?.as_name(); let name = f.field_name()?.as_name();
Some(RecordFieldPat { name, pat }) Some(RecordFieldPat { name, pat })
}) })
@ -923,9 +952,15 @@ impl ExprCollector<'_> {
// FIXME properly handle `RestPat` // FIXME properly handle `RestPat`
Pat::Slice { Pat::Slice {
prefix: prefix.into_iter().map(|p| self.collect_pat_(p)).collect(), prefix: prefix
slice: slice.map(|p| self.collect_pat_(p)), .into_iter()
suffix: suffix.into_iter().map(|p| self.collect_pat_(p)).collect(), .map(|p| self.collect_pat_(p, binding_list))
.collect(),
slice: slice.map(|p| self.collect_pat_(p, binding_list)),
suffix: suffix
.into_iter()
.map(|p| self.collect_pat_(p, binding_list))
.collect(),
} }
} }
ast::Pat::LiteralPat(lit) => { ast::Pat::LiteralPat(lit) => {
@ -948,7 +983,7 @@ impl ExprCollector<'_> {
Pat::Missing Pat::Missing
} }
ast::Pat::BoxPat(boxpat) => { ast::Pat::BoxPat(boxpat) => {
let inner = self.collect_pat_opt_(boxpat.pat()); let inner = self.collect_pat_opt_(boxpat.pat(), binding_list);
Pat::Box { inner } Pat::Box { inner }
} }
ast::Pat::ConstBlockPat(const_block_pat) => { ast::Pat::ConstBlockPat(const_block_pat) => {
@ -965,7 +1000,7 @@ impl ExprCollector<'_> {
let src = self.expander.to_source(Either::Left(AstPtr::new(&pat))); let src = self.expander.to_source(Either::Left(AstPtr::new(&pat)));
let pat = let pat =
self.collect_macro_call(call, macro_ptr, true, |this, expanded_pat| { self.collect_macro_call(call, macro_ptr, true, |this, expanded_pat| {
this.collect_pat_opt_(expanded_pat) this.collect_pat_opt_(expanded_pat, binding_list)
}); });
self.source_map.pat_map.insert(src, pat); self.source_map.pat_map.insert(src, pat);
return pat; return pat;
@ -979,21 +1014,25 @@ impl ExprCollector<'_> {
self.alloc_pat(pattern, Either::Left(ptr)) self.alloc_pat(pattern, Either::Left(ptr))
} }
fn collect_pat_opt_(&mut self, pat: Option<ast::Pat>) -> PatId { fn collect_pat_opt_(&mut self, pat: Option<ast::Pat>, binding_list: &mut BindingList) -> PatId {
match pat { match pat {
Some(pat) => self.collect_pat_(pat), Some(pat) => self.collect_pat_(pat, binding_list),
None => self.missing_pat(), None => self.missing_pat(),
} }
} }
fn collect_tuple_pat(&mut self, args: AstChildren<ast::Pat>) -> (Box<[PatId]>, Option<usize>) { fn collect_tuple_pat(
&mut self,
args: AstChildren<ast::Pat>,
binding_list: &mut BindingList,
) -> (Box<[PatId]>, Option<usize>) {
// Find the location of the `..`, if there is one. Note that we do not // Find the location of the `..`, if there is one. Note that we do not
// consider the possibility of there being multiple `..` here. // consider the possibility of there being multiple `..` here.
let ellipsis = args.clone().position(|p| matches!(p, ast::Pat::RestPat(_))); let ellipsis = args.clone().position(|p| matches!(p, ast::Pat::RestPat(_)));
// We want to skip the `..` pattern here, since we account for it above. // We want to skip the `..` pattern here, since we account for it above.
let args = args let args = args
.filter(|p| !matches!(p, ast::Pat::RestPat(_))) .filter(|p| !matches!(p, ast::Pat::RestPat(_)))
.map(|p| self.collect_pat_(p)) .map(|p| self.collect_pat_(p, binding_list))
.collect(); .collect();
(args, ellipsis) (args, ellipsis)
@ -1022,6 +1061,10 @@ impl ExprCollector<'_> {
None => Some(()), None => Some(()),
} }
} }
fn add_definition_to_binding(&mut self, binding_id: BindingId, pat_id: PatId) {
self.body.bindings[binding_id].definitions.push(pat_id);
}
} }
impl From<ast::LiteralKind> for Literal { impl From<ast::LiteralKind> for Literal {

View file

@ -5,7 +5,7 @@ use std::fmt::{self, Write};
use syntax::ast::HasName; use syntax::ast::HasName;
use crate::{ use crate::{
expr::{Array, BindingAnnotation, ClosureKind, Literal, Movability, Statement}, expr::{Array, BindingAnnotation, BindingId, ClosureKind, Literal, Movability, Statement},
pretty::{print_generic_args, print_path, print_type_ref}, pretty::{print_generic_args, print_path, print_type_ref},
type_ref::TypeRef, type_ref::TypeRef,
}; };
@ -524,14 +524,8 @@ impl<'a> Printer<'a> {
} }
Pat::Path(path) => self.print_path(path), Pat::Path(path) => self.print_path(path),
Pat::Lit(expr) => self.print_expr(*expr), Pat::Lit(expr) => self.print_expr(*expr),
Pat::Bind { mode, name, subpat } => { Pat::Bind { id, subpat } => {
let mode = match mode { self.print_binding(*id);
BindingAnnotation::Unannotated => "",
BindingAnnotation::Mutable => "mut ",
BindingAnnotation::Ref => "ref ",
BindingAnnotation::RefMut => "ref mut ",
};
w!(self, "{}{}", mode, name);
if let Some(pat) = subpat { if let Some(pat) = subpat {
self.whitespace(); self.whitespace();
self.print_pat(*pat); self.print_pat(*pat);
@ -635,4 +629,15 @@ impl<'a> Printer<'a> {
fn print_path(&mut self, path: &Path) { fn print_path(&mut self, path: &Path) {
print_path(path, self).unwrap(); print_path(path, self).unwrap();
} }
fn print_binding(&mut self, id: BindingId) {
let Binding { name, mode, .. } = &self.body.bindings[id];
let mode = match mode {
BindingAnnotation::Unannotated => "",
BindingAnnotation::Mutable => "mut ",
BindingAnnotation::Ref => "ref ",
BindingAnnotation::RefMut => "ref mut ",
};
w!(self, "{}{}", mode, name);
}
} }

View file

@ -8,7 +8,7 @@ use rustc_hash::FxHashMap;
use crate::{ use crate::{
body::Body, body::Body,
db::DefDatabase, db::DefDatabase,
expr::{Expr, ExprId, LabelId, Pat, PatId, Statement}, expr::{Binding, BindingId, Expr, ExprId, LabelId, Pat, PatId, Statement},
BlockId, DefWithBodyId, BlockId, DefWithBodyId,
}; };
@ -23,7 +23,7 @@ pub struct ExprScopes {
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct ScopeEntry { pub struct ScopeEntry {
name: Name, name: Name,
pat: PatId, binding: BindingId,
} }
impl ScopeEntry { impl ScopeEntry {
@ -31,8 +31,8 @@ impl ScopeEntry {
&self.name &self.name
} }
pub fn pat(&self) -> PatId { pub fn binding(&self) -> BindingId {
self.pat self.binding
} }
} }
@ -126,18 +126,23 @@ impl ExprScopes {
}) })
} }
fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) { fn add_bindings(&mut self, body: &Body, scope: ScopeId, binding: BindingId) {
let pattern = &body[pat]; let Binding { name, .. } = &body.bindings[binding];
if let Pat::Bind { name, .. } = pattern { let entry = ScopeEntry { name: name.clone(), binding };
let entry = ScopeEntry { name: name.clone(), pat };
self.scopes[scope].entries.push(entry); self.scopes[scope].entries.push(entry);
} }
pattern.walk_child_pats(|pat| self.add_bindings(body, scope, pat)); fn add_pat_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
let pattern = &body[pat];
if let Pat::Bind { id, .. } = pattern {
self.add_bindings(body, scope, *id);
}
pattern.walk_child_pats(|pat| self.add_pat_bindings(body, scope, pat));
} }
fn add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId]) { fn add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId]) {
params.iter().for_each(|pat| self.add_bindings(body, scope, *pat)); params.iter().for_each(|pat| self.add_pat_bindings(body, scope, *pat));
} }
fn set_scope(&mut self, node: ExprId, scope: ScopeId) { fn set_scope(&mut self, node: ExprId, scope: ScopeId) {
@ -170,7 +175,7 @@ fn compute_block_scopes(
} }
*scope = scopes.new_scope(*scope); *scope = scopes.new_scope(*scope);
scopes.add_bindings(body, *scope, *pat); scopes.add_pat_bindings(body, *scope, *pat);
} }
Statement::Expr { expr, .. } => { Statement::Expr { expr, .. } => {
compute_expr_scopes(*expr, body, scopes, scope); compute_expr_scopes(*expr, body, scopes, scope);
@ -208,7 +213,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope
Expr::For { iterable, pat, body: body_expr, label } => { Expr::For { iterable, pat, body: body_expr, label } => {
compute_expr_scopes(*iterable, body, scopes, scope); compute_expr_scopes(*iterable, body, scopes, scope);
let mut scope = scopes.new_labeled_scope(*scope, make_label(label)); let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
scopes.add_bindings(body, scope, *pat); scopes.add_pat_bindings(body, scope, *pat);
compute_expr_scopes(*body_expr, body, scopes, &mut scope); compute_expr_scopes(*body_expr, body, scopes, &mut scope);
} }
Expr::While { condition, body: body_expr, label } => { Expr::While { condition, body: body_expr, label } => {
@ -229,7 +234,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope
compute_expr_scopes(*expr, body, scopes, scope); compute_expr_scopes(*expr, body, scopes, scope);
for arm in arms.iter() { for arm in arms.iter() {
let mut scope = scopes.new_scope(*scope); let mut scope = scopes.new_scope(*scope);
scopes.add_bindings(body, scope, arm.pat); scopes.add_pat_bindings(body, scope, arm.pat);
if let Some(guard) = arm.guard { if let Some(guard) = arm.guard {
scope = scopes.new_scope(scope); scope = scopes.new_scope(scope);
compute_expr_scopes(guard, body, scopes, &mut scope); compute_expr_scopes(guard, body, scopes, &mut scope);
@ -248,7 +253,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope
&Expr::Let { pat, expr } => { &Expr::Let { pat, expr } => {
compute_expr_scopes(expr, body, scopes, scope); compute_expr_scopes(expr, body, scopes, scope);
*scope = scopes.new_scope(*scope); *scope = scopes.new_scope(*scope);
scopes.add_bindings(body, *scope, pat); scopes.add_pat_bindings(body, *scope, pat);
} }
e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)), e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
}; };
@ -450,7 +455,7 @@ fn foo() {
let function = find_function(&db, file_id); let function = find_function(&db, file_id);
let scopes = db.expr_scopes(function.into()); let scopes = db.expr_scopes(function.into());
let (_body, source_map) = db.body_with_source_map(function.into()); let (body, source_map) = db.body_with_source_map(function.into());
let expr_scope = { let expr_scope = {
let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap(); let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap();
@ -460,7 +465,9 @@ fn foo() {
}; };
let resolved = scopes.resolve_name_in_scope(expr_scope, &name_ref.as_name()).unwrap(); let resolved = scopes.resolve_name_in_scope(expr_scope, &name_ref.as_name()).unwrap();
let pat_src = source_map.pat_syntax(resolved.pat()).unwrap(); let pat_src = source_map
.pat_syntax(*body.bindings[resolved.binding()].definitions.first().unwrap())
.unwrap();
let local_name = pat_src.value.either( let local_name = pat_src.value.either(
|it| it.syntax_node_ptr().to_node(file.syntax()), |it| it.syntax_node_ptr().to_node(file.syntax()),

View file

@ -17,6 +17,7 @@ use std::fmt;
use hir_expand::name::Name; use hir_expand::name::Name;
use intern::Interned; use intern::Interned;
use la_arena::{Idx, RawIdx}; use la_arena::{Idx, RawIdx};
use smallvec::SmallVec;
use crate::{ use crate::{
builtin_type::{BuiltinFloat, BuiltinInt, BuiltinUint}, builtin_type::{BuiltinFloat, BuiltinInt, BuiltinUint},
@ -29,6 +30,8 @@ pub use syntax::ast::{ArithOp, BinaryOp, CmpOp, LogicOp, Ordering, RangeOp, Unar
pub type ExprId = Idx<Expr>; pub type ExprId = Idx<Expr>;
pub type BindingId = Idx<Binding>;
/// FIXME: this is a hacky function which should be removed /// FIXME: this is a hacky function which should be removed
pub(crate) fn dummy_expr_id() -> ExprId { pub(crate) fn dummy_expr_id() -> ExprId {
ExprId::from_raw(RawIdx::from(u32::MAX)) ExprId::from_raw(RawIdx::from(u32::MAX))
@ -433,6 +436,13 @@ impl BindingAnnotation {
} }
} }
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Binding {
pub name: Name,
pub mode: BindingAnnotation,
pub definitions: SmallVec<[PatId; 1]>,
}
#[derive(Debug, Clone, Eq, PartialEq)] #[derive(Debug, Clone, Eq, PartialEq)]
pub struct RecordFieldPat { pub struct RecordFieldPat {
pub name: Name, pub name: Name,
@ -451,7 +461,7 @@ pub enum Pat {
Slice { prefix: Box<[PatId]>, slice: Option<PatId>, suffix: Box<[PatId]> }, Slice { prefix: Box<[PatId]>, slice: Option<PatId>, suffix: Box<[PatId]> },
Path(Box<Path>), Path(Box<Path>),
Lit(ExprId), Lit(ExprId),
Bind { mode: BindingAnnotation, name: Name, subpat: Option<PatId> }, Bind { id: BindingId, subpat: Option<PatId> },
TupleStruct { path: Option<Box<Path>>, args: Box<[PatId]>, ellipsis: Option<usize> }, TupleStruct { path: Option<Box<Path>>, args: Box<[PatId]>, ellipsis: Option<usize> },
Ref { pat: PatId, mutability: Mutability }, Ref { pat: PatId, mutability: Mutability },
Box { inner: PatId }, Box { inner: PatId },

View file

@ -12,7 +12,7 @@ use crate::{
body::scope::{ExprScopes, ScopeId}, body::scope::{ExprScopes, ScopeId},
builtin_type::BuiltinType, builtin_type::BuiltinType,
db::DefDatabase, db::DefDatabase,
expr::{ExprId, LabelId, PatId}, expr::{BindingId, ExprId, LabelId},
generics::{GenericParams, TypeOrConstParamData}, generics::{GenericParams, TypeOrConstParamData},
item_scope::{BuiltinShadowMode, BUILTIN_SCOPE}, item_scope::{BuiltinShadowMode, BUILTIN_SCOPE},
nameres::DefMap, nameres::DefMap,
@ -105,7 +105,7 @@ pub enum ResolveValueResult {
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ValueNs { pub enum ValueNs {
ImplSelf(ImplId), ImplSelf(ImplId),
LocalBinding(PatId), LocalBinding(BindingId),
FunctionId(FunctionId), FunctionId(FunctionId),
ConstId(ConstId), ConstId(ConstId),
StaticId(StaticId), StaticId(StaticId),
@ -267,7 +267,7 @@ impl Resolver {
if let Some(e) = entry { if let Some(e) = entry {
return Some(ResolveValueResult::ValueNs(ValueNs::LocalBinding( return Some(ResolveValueResult::ValueNs(ValueNs::LocalBinding(
e.pat(), e.binding(),
))); )));
} }
} }
@ -617,7 +617,7 @@ pub enum ScopeDef {
ImplSelfType(ImplId), ImplSelfType(ImplId),
AdtSelfType(AdtId), AdtSelfType(AdtId),
GenericParam(GenericParamId), GenericParam(GenericParamId),
Local(PatId), Local(BindingId),
Label(LabelId), Label(LabelId),
} }
@ -669,7 +669,7 @@ impl Scope {
acc.add(&name, ScopeDef::Label(label)) acc.add(&name, ScopeDef::Label(label))
} }
scope.expr_scopes.entries(scope.scope_id).iter().for_each(|e| { scope.expr_scopes.entries(scope.scope_id).iter().for_each(|e| {
acc.add_local(e.name(), e.pat()); acc.add_local(e.name(), e.binding());
}); });
} }
} }
@ -859,7 +859,7 @@ impl ScopeNames {
self.add(name, ScopeDef::Unknown) self.add(name, ScopeDef::Unknown)
} }
} }
fn add_local(&mut self, name: &Name, pat: PatId) { fn add_local(&mut self, name: &Name, binding: BindingId) {
let set = self.map.entry(name.clone()).or_default(); let set = self.map.entry(name.clone()).or_default();
// XXX: hack, account for local (and only local) shadowing. // XXX: hack, account for local (and only local) shadowing.
// //
@ -870,7 +870,7 @@ impl ScopeNames {
cov_mark::hit!(shadowing_shows_single_completion); cov_mark::hit!(shadowing_shows_single_completion);
return; return;
} }
set.push(ScopeDef::Local(pat)) set.push(ScopeDef::Local(binding))
} }
} }

View file

@ -545,6 +545,49 @@ fn let_else() {
); );
} }
#[test]
fn function_param_patterns() {
check_number(
r#"
const fn f((a, b): &(u8, u8)) -> u8 {
*a + *b
}
const GOAL: u8 = f(&(2, 3));
"#,
5,
);
check_number(
r#"
const fn f(c @ (a, b): &(u8, u8)) -> u8 {
*a + *b + (*c).1
}
const GOAL: u8 = f(&(2, 3));
"#,
8,
);
check_number(
r#"
const fn f(ref a: u8) -> u8 {
*a
}
const GOAL: u8 = f(2);
"#,
2,
);
check_number(
r#"
struct Foo(u8);
impl Foo {
const fn f(&self, (a, b): &(u8, u8)) -> u8 {
self.0 + *a + *b
}
}
const GOAL: u8 = Foo(4).f(&(2, 3));
"#,
9,
);
}
#[test] #[test]
fn options() { fn options() {
check_number( check_number(

View file

@ -235,8 +235,8 @@ impl<'a> DeclValidator<'a> {
let pats_replacements = body let pats_replacements = body
.pats .pats
.iter() .iter()
.filter_map(|(id, pat)| match pat { .filter_map(|(pat_id, pat)| match pat {
Pat::Bind { name, .. } => Some((id, name)), Pat::Bind { id, .. } => Some((pat_id, &body.bindings[*id].name)),
_ => None, _ => None,
}) })
.filter_map(|(id, bind_name)| { .filter_map(|(id, bind_name)| {

View file

@ -146,8 +146,9 @@ impl<'a> PatCtxt<'a> {
PatKind::Leaf { subpatterns } PatKind::Leaf { subpatterns }
} }
hir_def::expr::Pat::Bind { ref name, subpat, .. } => { hir_def::expr::Pat::Bind { id, subpat, .. } => {
let bm = self.infer.pat_binding_modes[&pat]; let bm = self.infer.pat_binding_modes[&pat];
let name = &self.body.bindings[id].name;
match (bm, ty.kind(Interner)) { match (bm, ty.kind(Interner)) {
(BindingMode::Ref(_), TyKind::Ref(.., rty)) => ty = rty, (BindingMode::Ref(_), TyKind::Ref(.., rty)) => ty = rty,
(BindingMode::Ref(_), _) => { (BindingMode::Ref(_), _) => {

View file

@ -22,7 +22,7 @@ use hir_def::{
body::Body, body::Body,
builtin_type::{BuiltinInt, BuiltinType, BuiltinUint}, builtin_type::{BuiltinInt, BuiltinType, BuiltinUint},
data::{ConstData, StaticData}, data::{ConstData, StaticData},
expr::{BindingAnnotation, ExprId, ExprOrPatId, PatId}, expr::{BindingAnnotation, BindingId, ExprId, ExprOrPatId, PatId},
lang_item::{LangItem, LangItemTarget}, lang_item::{LangItem, LangItemTarget},
layout::Integer, layout::Integer,
path::Path, path::Path,
@ -352,6 +352,7 @@ pub struct InferenceResult {
/// **Note**: When a pattern type is resolved it may still contain /// **Note**: When a pattern type is resolved it may still contain
/// unresolved or missing subpatterns or subpatterns of mismatched types. /// unresolved or missing subpatterns or subpatterns of mismatched types.
pub type_of_pat: ArenaMap<PatId, Ty>, pub type_of_pat: ArenaMap<PatId, Ty>,
pub type_of_binding: ArenaMap<BindingId, Ty>,
pub type_of_rpit: ArenaMap<RpitId, Ty>, pub type_of_rpit: ArenaMap<RpitId, Ty>,
type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch>, type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch>,
/// Interned common types to return references to. /// Interned common types to return references to.
@ -414,6 +415,14 @@ impl Index<PatId> for InferenceResult {
} }
} }
impl Index<BindingId> for InferenceResult {
type Output = Ty;
fn index(&self, b: BindingId) -> &Ty {
self.type_of_binding.get(b).unwrap_or(&self.standard_types.unknown)
}
}
/// The inference context contains all information needed during type inference. /// The inference context contains all information needed during type inference.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub(crate) struct InferenceContext<'a> { pub(crate) struct InferenceContext<'a> {
@ -534,7 +543,10 @@ impl<'a> InferenceContext<'a> {
for ty in result.type_of_pat.values_mut() { for ty in result.type_of_pat.values_mut() {
*ty = table.resolve_completely(ty.clone()); *ty = table.resolve_completely(ty.clone());
} }
for ty in result.type_of_rpit.iter_mut().map(|x| x.1) { for ty in result.type_of_binding.values_mut() {
*ty = table.resolve_completely(ty.clone());
}
for ty in result.type_of_rpit.values_mut() {
*ty = table.resolve_completely(ty.clone()); *ty = table.resolve_completely(ty.clone());
} }
for mismatch in result.type_mismatches.values_mut() { for mismatch in result.type_mismatches.values_mut() {
@ -704,6 +716,10 @@ impl<'a> InferenceContext<'a> {
self.result.type_of_pat.insert(pat, ty); self.result.type_of_pat.insert(pat, ty);
} }
fn write_binding_ty(&mut self, id: BindingId, ty: Ty) {
self.result.type_of_binding.insert(id, ty);
}
fn push_diagnostic(&mut self, diagnostic: InferenceDiagnostic) { fn push_diagnostic(&mut self, diagnostic: InferenceDiagnostic) {
self.result.diagnostics.push(diagnostic); self.result.diagnostics.push(diagnostic);
} }

View file

@ -5,7 +5,7 @@ use std::iter::repeat_with;
use chalk_ir::Mutability; use chalk_ir::Mutability;
use hir_def::{ use hir_def::{
body::Body, body::Body,
expr::{BindingAnnotation, Expr, ExprId, ExprOrPatId, Literal, Pat, PatId, RecordFieldPat}, expr::{Binding, BindingAnnotation, Expr, ExprId, ExprOrPatId, Literal, Pat, PatId, RecordFieldPat, BindingId},
path::Path, path::Path,
}; };
use hir_expand::name::Name; use hir_expand::name::Name;
@ -248,8 +248,8 @@ impl<'a> InferenceContext<'a> {
// FIXME update resolver for the surrounding expression // FIXME update resolver for the surrounding expression
self.infer_path(path, pat.into()).unwrap_or_else(|| self.err_ty()) self.infer_path(path, pat.into()).unwrap_or_else(|| self.err_ty())
} }
Pat::Bind { mode, name: _, subpat } => { Pat::Bind { id, subpat } => {
return self.infer_bind_pat(pat, *mode, default_bm, *subpat, &expected); return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected);
} }
Pat::Slice { prefix, slice, suffix } => { Pat::Slice { prefix, slice, suffix } => {
self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm) self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm)
@ -320,11 +320,12 @@ impl<'a> InferenceContext<'a> {
fn infer_bind_pat( fn infer_bind_pat(
&mut self, &mut self,
pat: PatId, pat: PatId,
mode: BindingAnnotation, binding: BindingId,
default_bm: BindingMode, default_bm: BindingMode,
subpat: Option<PatId>, subpat: Option<PatId>,
expected: &Ty, expected: &Ty,
) -> Ty { ) -> Ty {
let Binding { mode, .. } = self.body.bindings[binding];
let mode = if mode == BindingAnnotation::Unannotated { let mode = if mode == BindingAnnotation::Unannotated {
default_bm default_bm
} else { } else {
@ -344,7 +345,8 @@ impl<'a> InferenceContext<'a> {
} }
BindingMode::Move => inner_ty.clone(), BindingMode::Move => inner_ty.clone(),
}; };
self.write_pat_ty(pat, bound_ty); self.write_pat_ty(pat, bound_ty.clone());
self.write_binding_ty(binding, bound_ty);
return inner_ty; return inner_ty;
} }
@ -420,11 +422,14 @@ fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool {
Pat::Lit(expr) => { Pat::Lit(expr) => {
!matches!(body[*expr], Expr::Literal(Literal::String(..) | Literal::ByteString(..))) !matches!(body[*expr], Expr::Literal(Literal::String(..) | Literal::ByteString(..)))
} }
Pat::Bind { Pat::Bind { id, subpat: Some(subpat), .. }
mode: BindingAnnotation::Mutable | BindingAnnotation::Unannotated, if matches!(
subpat: Some(subpat), body.bindings[*id].mode,
.. BindingAnnotation::Mutable | BindingAnnotation::Unannotated
} => is_non_ref_pat(body, *subpat), ) =>
{
is_non_ref_pat(body, *subpat)
}
Pat::Wild | Pat::Bind { .. } | Pat::Ref { .. } | Pat::Box { .. } | Pat::Missing => false, Pat::Wild | Pat::Bind { .. } | Pat::Ref { .. } | Pat::Box { .. } | Pat::Missing => false,
} }
} }
@ -432,7 +437,7 @@ fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool {
pub(super) fn contains_explicit_ref_binding(body: &Body, pat_id: PatId) -> bool { pub(super) fn contains_explicit_ref_binding(body: &Body, pat_id: PatId) -> bool {
let mut res = false; let mut res = false;
walk_pats(body, pat_id, &mut |pat| { walk_pats(body, pat_id, &mut |pat| {
res |= matches!(pat, Pat::Bind { mode: BindingAnnotation::Ref, .. }) res |= matches!(pat, Pat::Bind { id, .. } if body.bindings[*id].mode == BindingAnnotation::Ref);
}); });
res res
} }

View file

@ -50,7 +50,7 @@ impl<'a> InferenceContext<'a> {
}; };
let typable: ValueTyDefId = match value { let typable: ValueTyDefId = match value {
ValueNs::LocalBinding(pat) => match self.result.type_of_pat.get(pat) { ValueNs::LocalBinding(pat) => match self.result.type_of_binding.get(pat) {
Some(ty) => return Some(ty.clone()), Some(ty) => return Some(ty.clone()),
None => { None => {
never!("uninferred pattern?"); never!("uninferred pattern?");

View file

@ -65,17 +65,9 @@ fn eval_expr(ra_fixture: &str, minicore: &str) -> Result<Layout, LayoutError> {
}) })
.unwrap(); .unwrap();
let hir_body = db.body(adt_id.into()); let hir_body = db.body(adt_id.into());
let pat = hir_body let b = hir_body.bindings.iter().find(|x| x.1.name.to_smol_str() == "goal").unwrap().0;
.pats
.iter()
.find(|x| match x.1 {
hir_def::expr::Pat::Bind { name, .. } => name.to_smol_str() == "goal",
_ => false,
})
.unwrap()
.0;
let infer = db.infer(adt_id.into()); let infer = db.infer(adt_id.into());
let goal_ty = infer.type_of_pat[pat].clone(); let goal_ty = infer.type_of_binding[b].clone();
layout_of_ty(&db, &goal_ty, module_id.krate()) layout_of_ty(&db, &goal_ty, module_id.krate())
} }

View file

@ -6,7 +6,8 @@ use chalk_ir::{BoundVar, ConstData, DebruijnIndex, TyKind};
use hir_def::{ use hir_def::{
body::Body, body::Body,
expr::{ expr::{
Array, BindingAnnotation, ExprId, LabelId, Literal, MatchArm, Pat, PatId, RecordLitField, Array, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, Pat, PatId,
RecordLitField,
}, },
layout::LayoutError, layout::LayoutError,
resolver::{resolver_for_expr, ResolveValueResult, ValueNs}, resolver::{resolver_for_expr, ResolveValueResult, ValueNs},
@ -30,7 +31,7 @@ struct LoopBlocks {
struct MirLowerCtx<'a> { struct MirLowerCtx<'a> {
result: MirBody, result: MirBody,
owner: DefWithBodyId, owner: DefWithBodyId,
binding_locals: ArenaMap<PatId, LocalId>, binding_locals: ArenaMap<BindingId, LocalId>,
current_loop_blocks: Option<LoopBlocks>, current_loop_blocks: Option<LoopBlocks>,
discr_temp: Option<Place>, discr_temp: Option<Place>,
db: &'a dyn HirDatabase, db: &'a dyn HirDatabase,
@ -43,7 +44,9 @@ pub enum MirLowerError {
ConstEvalError(Box<ConstEvalError>), ConstEvalError(Box<ConstEvalError>),
LayoutError(LayoutError), LayoutError(LayoutError),
IncompleteExpr, IncompleteExpr,
UnresolvedName, UnresolvedName(String),
UnresolvedMethod,
UnresolvedField,
MissingFunctionDefinition, MissingFunctionDefinition,
TypeError(&'static str), TypeError(&'static str),
NotSupported(String), NotSupported(String),
@ -222,22 +225,23 @@ impl MirLowerCtx<'_> {
match &self.body.exprs[expr_id] { match &self.body.exprs[expr_id] {
Expr::Missing => Err(MirLowerError::IncompleteExpr), Expr::Missing => Err(MirLowerError::IncompleteExpr),
Expr::Path(p) => { Expr::Path(p) => {
let unresolved_name = || MirLowerError::UnresolvedName("".to_string());
let resolver = resolver_for_expr(self.db.upcast(), self.owner, expr_id); let resolver = resolver_for_expr(self.db.upcast(), self.owner, expr_id);
let pr = resolver let pr = resolver
.resolve_path_in_value_ns(self.db.upcast(), p.mod_path()) .resolve_path_in_value_ns(self.db.upcast(), p.mod_path())
.ok_or(MirLowerError::UnresolvedName)?; .ok_or_else(unresolved_name)?;
let pr = match pr { let pr = match pr {
ResolveValueResult::ValueNs(v) => v, ResolveValueResult::ValueNs(v) => v,
ResolveValueResult::Partial(..) => { ResolveValueResult::Partial(..) => {
return match self return match self
.infer .infer
.assoc_resolutions_for_expr(expr_id) .assoc_resolutions_for_expr(expr_id)
.ok_or(MirLowerError::UnresolvedName)? .ok_or_else(unresolved_name)?
.0 .0
//.ok_or(ConstEvalError::SemanticError("unresolved assoc item"))? //.ok_or(ConstEvalError::SemanticError("unresolved assoc item"))?
{ {
hir_def::AssocItemId::ConstId(c) => self.lower_const(c, current, place), hir_def::AssocItemId::ConstId(c) => self.lower_const(c, current, place),
_ => return Err(MirLowerError::UnresolvedName), _ => return Err(unresolved_name()),
}; };
} }
}; };
@ -394,7 +398,7 @@ impl MirLowerCtx<'_> {
} }
Expr::MethodCall { receiver, args, .. } => { Expr::MethodCall { receiver, args, .. } => {
let (func_id, generic_args) = let (func_id, generic_args) =
self.infer.method_resolution(expr_id).ok_or(MirLowerError::UnresolvedName)?; self.infer.method_resolution(expr_id).ok_or(MirLowerError::UnresolvedMethod)?;
let ty = chalk_ir::TyKind::FnDef( let ty = chalk_ir::TyKind::FnDef(
CallableDefId::FunctionId(func_id).to_chalk(self.db), CallableDefId::FunctionId(func_id).to_chalk(self.db),
generic_args, generic_args,
@ -476,7 +480,7 @@ impl MirLowerCtx<'_> {
let variant_id = self let variant_id = self
.infer .infer
.variant_resolution_for_expr(expr_id) .variant_resolution_for_expr(expr_id)
.ok_or(MirLowerError::UnresolvedName)?; .ok_or_else(|| MirLowerError::UnresolvedName("".to_string()))?;
let subst = match self.expr_ty(expr_id).kind(Interner) { let subst = match self.expr_ty(expr_id).kind(Interner) {
TyKind::Adt(_, s) => s.clone(), TyKind::Adt(_, s) => s.clone(),
_ => not_supported!("Non ADT record literal"), _ => not_supported!("Non ADT record literal"),
@ -487,7 +491,7 @@ impl MirLowerCtx<'_> {
let mut operands = vec![None; variant_data.fields().len()]; let mut operands = vec![None; variant_data.fields().len()];
for RecordLitField { name, expr } in fields.iter() { for RecordLitField { name, expr } in fields.iter() {
let field_id = let field_id =
variant_data.field(name).ok_or(MirLowerError::UnresolvedName)?; variant_data.field(name).ok_or(MirLowerError::UnresolvedField)?;
let op; let op;
(op, current) = self.lower_expr_to_some_operand(*expr, current)?; (op, current) = self.lower_expr_to_some_operand(*expr, current)?;
operands[u32::from(field_id.into_raw()) as usize] = Some(op); operands[u32::from(field_id.into_raw()) as usize] = Some(op);
@ -509,7 +513,7 @@ impl MirLowerCtx<'_> {
not_supported!("Union record literal with more than one field"); not_supported!("Union record literal with more than one field");
}; };
let local_id = let local_id =
variant_data.field(name).ok_or(MirLowerError::UnresolvedName)?; variant_data.field(name).ok_or(MirLowerError::UnresolvedField)?;
let mut place = place; let mut place = place;
place place
.projection .projection
@ -529,7 +533,7 @@ impl MirLowerCtx<'_> {
let field = self let field = self
.infer .infer
.field_resolution(expr_id) .field_resolution(expr_id)
.ok_or(MirLowerError::UnresolvedName)?; .ok_or(MirLowerError::UnresolvedField)?;
current_place.projection.push(ProjectionElem::Field(field)); current_place.projection.push(ProjectionElem::Field(field));
} }
self.push_assignment(current, place, Operand::Copy(current_place).into()); self.push_assignment(current, place, Operand::Copy(current_place).into());
@ -962,8 +966,9 @@ impl MirLowerCtx<'_> {
} }
(then_target, Some(else_target)) (then_target, Some(else_target))
} }
Pat::Bind { mode, name: _, subpat } => { Pat::Bind { id, subpat } => {
let target_place = self.binding_locals[pattern]; let target_place = self.binding_locals[*id];
let mode = self.body.bindings[*id].mode;
if let Some(subpat) = subpat { if let Some(subpat) = subpat {
(current, current_else) = self.pattern_match( (current, current_else) = self.pattern_match(
current, current,
@ -975,7 +980,7 @@ impl MirLowerCtx<'_> {
)? )?
} }
if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) { if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) {
binding_mode = *mode; binding_mode = mode;
} }
self.push_assignment( self.push_assignment(
current, current,
@ -1189,17 +1194,40 @@ pub fn lower_to_mir(
let mut locals = Arena::new(); let mut locals = Arena::new();
// 0 is return local // 0 is return local
locals.alloc(Local { mutability: Mutability::Mut, ty: infer[root_expr].clone() }); locals.alloc(Local { mutability: Mutability::Mut, ty: infer[root_expr].clone() });
let mut create_local_of_path = |p: PatId| { let mut binding_locals: ArenaMap<BindingId, LocalId> = ArenaMap::new();
// FIXME: mutablity is broken let param_locals: ArenaMap<PatId, LocalId> = if let DefWithBodyId::FunctionId(fid) = owner {
locals.alloc(Local { mutability: Mutability::Not, ty: infer[p].clone() }) let substs = TyBuilder::placeholder_subst(db, fid);
}; let callable_sig = db.callable_item_signature(fid.into()).substitute(Interner, &substs);
// 1 to param_len is for params // 1 to param_len is for params
let mut binding_locals: ArenaMap<PatId, LocalId> = body.params
body.params.iter().map(|&x| (x, create_local_of_path(x))).collect(); .iter()
.zip(callable_sig.params().iter())
.map(|(&x, ty)| {
let local_id = locals.alloc(Local { mutability: Mutability::Not, ty: ty.clone() });
if let Pat::Bind { id, subpat: None } = body[x] {
if matches!(
body.bindings[id].mode,
BindingAnnotation::Unannotated | BindingAnnotation::Mutable
) {
binding_locals.insert(id, local_id);
}
}
(x, local_id)
})
.collect()
} else {
if !body.params.is_empty() {
return Err(MirLowerError::TypeError("Unexpected parameter for non function body"));
}
ArenaMap::new()
};
// and then rest of bindings // and then rest of bindings
for (pat_id, _) in body.pats.iter() { for (id, _) in body.bindings.iter() {
if !binding_locals.contains_idx(pat_id) { if !binding_locals.contains_idx(id) {
binding_locals.insert(pat_id, create_local_of_path(pat_id)); binding_locals.insert(
id,
locals.alloc(Local { mutability: Mutability::Not, ty: infer[id].clone() }),
);
} }
} }
let mir = MirBody { basic_blocks, locals, start_block, owner, arg_count: body.params.len() }; let mir = MirBody { basic_blocks, locals, start_block, owner, arg_count: body.params.len() };
@ -1213,7 +1241,27 @@ pub fn lower_to_mir(
current_loop_blocks: None, current_loop_blocks: None,
discr_temp: None, discr_temp: None,
}; };
let b = ctx.lower_expr_to_place(root_expr, return_slot().into(), start_block)?; let mut current = start_block;
for &param in &body.params {
if let Pat::Bind { id, .. } = body[param] {
if param_locals[param] == ctx.binding_locals[id] {
continue;
}
}
let r = ctx.pattern_match(
current,
None,
param_locals[param].into(),
ctx.result.locals[param_locals[param]].ty.clone(),
param,
BindingAnnotation::Unannotated,
)?;
if let Some(b) = r.1 {
ctx.set_terminator(b, Terminator::Unreachable);
}
current = r.0;
}
let b = ctx.lower_expr_to_place(root_expr, return_slot().into(), current)?;
ctx.result.basic_blocks[b].terminator = Some(Terminator::Return); ctx.result.basic_blocks[b].terminator = Some(Terminator::Return);
Ok(ctx.result) Ok(ctx.result)
} }

View file

@ -4,7 +4,7 @@
//! are splitting the hir. //! are splitting the hir.
use hir_def::{ use hir_def::{
expr::{LabelId, PatId}, expr::{BindingId, LabelId},
AdtId, AssocItemId, DefWithBodyId, EnumVariantId, FieldId, GenericDefId, GenericParamId, AdtId, AssocItemId, DefWithBodyId, EnumVariantId, FieldId, GenericDefId, GenericParamId,
ModuleDefId, VariantId, ModuleDefId, VariantId,
}; };
@ -251,9 +251,9 @@ impl From<AssocItem> for GenericDefId {
} }
} }
impl From<(DefWithBodyId, PatId)> for Local { impl From<(DefWithBodyId, BindingId)> for Local {
fn from((parent, pat_id): (DefWithBodyId, PatId)) -> Self { fn from((parent, binding_id): (DefWithBodyId, BindingId)) -> Self {
Local { parent, pat_id } Local { parent, binding_id }
} }
} }

View file

@ -41,7 +41,7 @@ use either::Either;
use hir_def::{ use hir_def::{
adt::VariantData, adt::VariantData,
body::{BodyDiagnostic, SyntheticSyntax}, body::{BodyDiagnostic, SyntheticSyntax},
expr::{BindingAnnotation, ExprOrPatId, LabelId, Pat, PatId}, expr::{BindingAnnotation, BindingId, ExprOrPatId, LabelId, Pat},
generics::{LifetimeParamData, TypeOrConstParamData, TypeParamProvenance}, generics::{LifetimeParamData, TypeOrConstParamData, TypeParamProvenance},
item_tree::ItemTreeNode, item_tree::ItemTreeNode,
lang_item::{LangItem, LangItemTarget}, lang_item::{LangItem, LangItemTarget},
@ -77,7 +77,7 @@ use rustc_hash::FxHashSet;
use stdx::{impl_from, never}; use stdx::{impl_from, never};
use syntax::{ use syntax::{
ast::{self, HasAttrs as _, HasDocComments, HasName}, ast::{self, HasAttrs as _, HasDocComments, HasName},
AstNode, AstPtr, SmolStr, SyntaxNodePtr, TextRange, T, AstNode, AstPtr, SmolStr, SyntaxNode, SyntaxNodePtr, TextRange, T,
}; };
use crate::db::{DefDatabase, HirDatabase}; use crate::db::{DefDatabase, HirDatabase};
@ -1782,8 +1782,8 @@ impl Param {
let parent = DefWithBodyId::FunctionId(self.func.into()); let parent = DefWithBodyId::FunctionId(self.func.into());
let body = db.body(parent); let body = db.body(parent);
let pat_id = body.params[self.idx]; let pat_id = body.params[self.idx];
if let Pat::Bind { .. } = &body[pat_id] { if let Pat::Bind { id, .. } = &body[pat_id] {
Some(Local { parent, pat_id: body.params[self.idx] }) Some(Local { parent, binding_id: *id })
} else { } else {
None None
} }
@ -2460,13 +2460,42 @@ impl GenericDef {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Local { pub struct Local {
pub(crate) parent: DefWithBodyId, pub(crate) parent: DefWithBodyId,
pub(crate) pat_id: PatId, pub(crate) binding_id: BindingId,
}
pub struct LocalSource {
pub local: Local,
pub source: InFile<Either<ast::IdentPat, ast::SelfParam>>,
}
impl LocalSource {
pub fn as_ident_pat(&self) -> Option<&ast::IdentPat> {
match &self.source.value {
Either::Left(x) => Some(x),
Either::Right(_) => None,
}
}
pub fn into_ident_pat(self) -> Option<ast::IdentPat> {
match self.source.value {
Either::Left(x) => Some(x),
Either::Right(_) => None,
}
}
pub fn original_file(&self, db: &dyn HirDatabase) -> FileId {
self.source.file_id.original_file(db.upcast())
}
pub fn syntax(&self) -> &SyntaxNode {
self.source.value.syntax()
}
} }
impl Local { impl Local {
pub fn is_param(self, db: &dyn HirDatabase) -> bool { pub fn is_param(self, db: &dyn HirDatabase) -> bool {
let src = self.source(db); let src = self.primary_source(db);
match src.value { match src.source.value {
Either::Left(pat) => pat Either::Left(pat) => pat
.syntax() .syntax()
.ancestors() .ancestors()
@ -2486,13 +2515,7 @@ impl Local {
pub fn name(self, db: &dyn HirDatabase) -> Name { pub fn name(self, db: &dyn HirDatabase) -> Name {
let body = db.body(self.parent); let body = db.body(self.parent);
match &body[self.pat_id] { body[self.binding_id].name.clone()
Pat::Bind { name, .. } => name.clone(),
_ => {
stdx::never!("hir::Local is missing a name!");
Name::missing()
}
}
} }
pub fn is_self(self, db: &dyn HirDatabase) -> bool { pub fn is_self(self, db: &dyn HirDatabase) -> bool {
@ -2501,15 +2524,12 @@ impl Local {
pub fn is_mut(self, db: &dyn HirDatabase) -> bool { pub fn is_mut(self, db: &dyn HirDatabase) -> bool {
let body = db.body(self.parent); let body = db.body(self.parent);
matches!(&body[self.pat_id], Pat::Bind { mode: BindingAnnotation::Mutable, .. }) body[self.binding_id].mode == BindingAnnotation::Mutable
} }
pub fn is_ref(self, db: &dyn HirDatabase) -> bool { pub fn is_ref(self, db: &dyn HirDatabase) -> bool {
let body = db.body(self.parent); let body = db.body(self.parent);
matches!( matches!(body[self.binding_id].mode, BindingAnnotation::Ref | BindingAnnotation::RefMut)
&body[self.pat_id],
Pat::Bind { mode: BindingAnnotation::Ref | BindingAnnotation::RefMut, .. }
)
} }
pub fn parent(self, _db: &dyn HirDatabase) -> DefWithBody { pub fn parent(self, _db: &dyn HirDatabase) -> DefWithBody {
@ -2523,34 +2543,33 @@ impl Local {
pub fn ty(self, db: &dyn HirDatabase) -> Type { pub fn ty(self, db: &dyn HirDatabase) -> Type {
let def = self.parent; let def = self.parent;
let infer = db.infer(def); let infer = db.infer(def);
let ty = infer[self.pat_id].clone(); let ty = infer[self.binding_id].clone();
Type::new(db, def, ty) Type::new(db, def, ty)
} }
pub fn associated_locals(self, db: &dyn HirDatabase) -> Box<[Local]> { /// All definitions for this local. Example: `let (a$0, _) | (_, a$0) = x;`
let body = db.body(self.parent); pub fn sources(self, db: &dyn HirDatabase) -> Vec<LocalSource> {
body.ident_patterns_for(&self.pat_id) let (body, source_map) = db.body_with_source_map(self.parent);
body[self.binding_id]
.definitions
.iter() .iter()
.map(|&pat_id| Local { parent: self.parent, pat_id }) .map(|&definition| {
.collect() let src = source_map.pat_syntax(definition).unwrap(); // Hmm...
}
/// If this local is part of a multi-local, retrieve the representative local.
/// That is the local that references are being resolved to.
pub fn representative(self, db: &dyn HirDatabase) -> Local {
let body = db.body(self.parent);
Local { pat_id: body.pattern_representative(self.pat_id), ..self }
}
pub fn source(self, db: &dyn HirDatabase) -> InFile<Either<ast::IdentPat, ast::SelfParam>> {
let (_body, source_map) = db.body_with_source_map(self.parent);
let src = source_map.pat_syntax(self.pat_id).unwrap(); // Hmm...
let root = src.file_syntax(db.upcast()); let root = src.file_syntax(db.upcast());
src.map(|ast| match ast { src.map(|ast| match ast {
// Suspicious unwrap // Suspicious unwrap
Either::Left(it) => Either::Left(it.cast().unwrap().to_node(&root)), Either::Left(it) => Either::Left(it.cast().unwrap().to_node(&root)),
Either::Right(it) => Either::Right(it.to_node(&root)), Either::Right(it) => Either::Right(it.to_node(&root)),
}) })
})
.map(|source| LocalSource { local: self, source })
.collect()
}
/// The leftmost definition for this local. Example: `let (a$0, _) | (_, a) = x;`
pub fn primary_source(self, db: &dyn HirDatabase) -> LocalSource {
let all_sources = self.sources(db);
all_sources.into_iter().next().unwrap()
} }
} }

View file

@ -1654,8 +1654,8 @@ impl<'a> SemanticsScope<'a> {
resolver::ScopeDef::ImplSelfType(it) => ScopeDef::ImplSelfType(it.into()), resolver::ScopeDef::ImplSelfType(it) => ScopeDef::ImplSelfType(it.into()),
resolver::ScopeDef::AdtSelfType(it) => ScopeDef::AdtSelfType(it.into()), resolver::ScopeDef::AdtSelfType(it) => ScopeDef::AdtSelfType(it.into()),
resolver::ScopeDef::GenericParam(id) => ScopeDef::GenericParam(id.into()), resolver::ScopeDef::GenericParam(id) => ScopeDef::GenericParam(id.into()),
resolver::ScopeDef::Local(pat_id) => match self.resolver.body_owner() { resolver::ScopeDef::Local(binding_id) => match self.resolver.body_owner() {
Some(parent) => ScopeDef::Local(Local { parent, pat_id }), Some(parent) => ScopeDef::Local(Local { parent, binding_id }),
None => continue, None => continue,
}, },
resolver::ScopeDef::Label(label_id) => match self.resolver.body_owner() { resolver::ScopeDef::Label(label_id) => match self.resolver.body_owner() {

View file

@ -89,7 +89,7 @@ use base_db::FileId;
use hir_def::{ use hir_def::{
child_by_source::ChildBySource, child_by_source::ChildBySource,
dyn_map::DynMap, dyn_map::DynMap,
expr::{LabelId, PatId}, expr::{BindingId, LabelId},
keys::{self, Key}, keys::{self, Key},
AdtId, ConstId, ConstParamId, DefWithBodyId, EnumId, EnumVariantId, FieldId, FunctionId, AdtId, ConstId, ConstParamId, DefWithBodyId, EnumId, EnumVariantId, FieldId, FunctionId,
GenericDefId, GenericParamId, ImplId, LifetimeParamId, MacroId, ModuleId, StaticId, StructId, GenericDefId, GenericParamId, ImplId, LifetimeParamId, MacroId, ModuleId, StaticId, StructId,
@ -98,7 +98,7 @@ use hir_def::{
use hir_expand::{attrs::AttrId, name::AsName, HirFileId, MacroCallId}; use hir_expand::{attrs::AttrId, name::AsName, HirFileId, MacroCallId};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use smallvec::SmallVec; use smallvec::SmallVec;
use stdx::impl_from; use stdx::{impl_from, never};
use syntax::{ use syntax::{
ast::{self, HasName}, ast::{self, HasName},
AstNode, SyntaxNode, AstNode, SyntaxNode,
@ -216,14 +216,14 @@ impl SourceToDefCtx<'_, '_> {
pub(super) fn bind_pat_to_def( pub(super) fn bind_pat_to_def(
&mut self, &mut self,
src: InFile<ast::IdentPat>, src: InFile<ast::IdentPat>,
) -> Option<(DefWithBodyId, PatId)> { ) -> Option<(DefWithBodyId, BindingId)> {
let container = self.find_pat_or_label_container(src.syntax())?; let container = self.find_pat_or_label_container(src.syntax())?;
let (body, source_map) = self.db.body_with_source_map(container); let (body, source_map) = self.db.body_with_source_map(container);
let src = src.map(ast::Pat::from); let src = src.map(ast::Pat::from);
let pat_id = source_map.node_pat(src.as_ref())?; let pat_id = source_map.node_pat(src.as_ref())?;
// the pattern could resolve to a constant, verify that that is not the case // the pattern could resolve to a constant, verify that that is not the case
if let crate::Pat::Bind { .. } = body[pat_id] { if let crate::Pat::Bind { id, .. } = body[pat_id] {
Some((container, pat_id)) Some((container, id))
} else { } else {
None None
} }
@ -231,11 +231,16 @@ impl SourceToDefCtx<'_, '_> {
pub(super) fn self_param_to_def( pub(super) fn self_param_to_def(
&mut self, &mut self,
src: InFile<ast::SelfParam>, src: InFile<ast::SelfParam>,
) -> Option<(DefWithBodyId, PatId)> { ) -> Option<(DefWithBodyId, BindingId)> {
let container = self.find_pat_or_label_container(src.syntax())?; let container = self.find_pat_or_label_container(src.syntax())?;
let (_body, source_map) = self.db.body_with_source_map(container); let (body, source_map) = self.db.body_with_source_map(container);
let pat_id = source_map.node_self_param(src.as_ref())?; let pat_id = source_map.node_self_param(src.as_ref())?;
Some((container, pat_id)) if let crate::Pat::Bind { id, .. } = body[pat_id] {
Some((container, id))
} else {
never!();
None
}
} }
pub(super) fn label_to_def( pub(super) fn label_to_def(
&mut self, &mut self,

View file

@ -422,8 +422,8 @@ impl SourceAnalyzer {
// Shorthand syntax, resolve to the local // Shorthand syntax, resolve to the local
let path = ModPath::from_segments(PathKind::Plain, once(local_name.clone())); let path = ModPath::from_segments(PathKind::Plain, once(local_name.clone()));
match self.resolver.resolve_path_in_value_ns_fully(db.upcast(), &path) { match self.resolver.resolve_path_in_value_ns_fully(db.upcast(), &path) {
Some(ValueNs::LocalBinding(pat_id)) => { Some(ValueNs::LocalBinding(binding_id)) => {
Some(Local { pat_id, parent: self.resolver.body_owner()? }) Some(Local { binding_id, parent: self.resolver.body_owner()? })
} }
_ => None, _ => None,
} }
@ -1018,8 +1018,8 @@ fn resolve_hir_path_(
let values = || { let values = || {
resolver.resolve_path_in_value_ns_fully(db.upcast(), path.mod_path()).and_then(|val| { resolver.resolve_path_in_value_ns_fully(db.upcast(), path.mod_path()).and_then(|val| {
let res = match val { let res = match val {
ValueNs::LocalBinding(pat_id) => { ValueNs::LocalBinding(binding_id) => {
let var = Local { parent: body_owner?, pat_id }; let var = Local { parent: body_owner?, binding_id };
PathResolution::Local(var) PathResolution::Local(var)
} }
ValueNs::FunctionId(it) => PathResolution::Def(Function::from(it).into()), ValueNs::FunctionId(it) => PathResolution::Def(Function::from(it).into()),

View file

@ -101,7 +101,7 @@ fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Opti
let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?; let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
match NameRefClass::classify(&ctx.sema, &name_ref)? { match NameRefClass::classify(&ctx.sema, &name_ref)? {
NameRefClass::Definition(Definition::Local(local)) => { NameRefClass::Definition(Definition::Local(local)) => {
let source = local.source(ctx.db()).value.left()?; let source = local.primary_source(ctx.db()).into_ident_pat()?;
Some(source.name()?) Some(source.name()?)
} }
_ => None, _ => None,

View file

@ -3,7 +3,8 @@ use std::iter;
use ast::make; use ast::make;
use either::Either; use either::Either;
use hir::{ use hir::{
HasSource, HirDisplay, InFile, Local, ModuleDef, PathResolution, Semantics, TypeInfo, TypeParam, HasSource, HirDisplay, InFile, Local, LocalSource, ModuleDef, PathResolution, Semantics,
TypeInfo, TypeParam,
}; };
use ide_db::{ use ide_db::{
defs::{Definition, NameRefClass}, defs::{Definition, NameRefClass},
@ -710,7 +711,7 @@ impl FunctionBody {
) => local_ref, ) => local_ref,
_ => return, _ => return,
}; };
let InFile { file_id, value } = local_ref.source(sema.db); let InFile { file_id, value } = local_ref.primary_source(sema.db).source;
// locals defined inside macros are not relevant to us // locals defined inside macros are not relevant to us
if !file_id.is_macro() { if !file_id.is_macro() {
match value { match value {
@ -972,11 +973,11 @@ impl FunctionBody {
locals: impl Iterator<Item = Local>, locals: impl Iterator<Item = Local>,
) -> Vec<Param> { ) -> Vec<Param> {
locals locals
.map(|local| (local, local.source(ctx.db()))) .map(|local| (local, local.primary_source(ctx.db())))
.filter(|(_, src)| is_defined_outside_of_body(ctx, self, src)) .filter(|(_, src)| is_defined_outside_of_body(ctx, self, src))
.filter_map(|(local, src)| match src.value { .filter_map(|(local, src)| match src.into_ident_pat() {
Either::Left(src) => Some((local, src)), Some(src) => Some((local, src)),
Either::Right(_) => { None => {
stdx::never!(false, "Local::is_self returned false, but source is SelfParam"); stdx::never!(false, "Local::is_self returned false, but source is SelfParam");
None None
} }
@ -1238,17 +1239,9 @@ fn local_outlives_body(
fn is_defined_outside_of_body( fn is_defined_outside_of_body(
ctx: &AssistContext<'_>, ctx: &AssistContext<'_>,
body: &FunctionBody, body: &FunctionBody,
src: &hir::InFile<Either<ast::IdentPat, ast::SelfParam>>, src: &LocalSource,
) -> bool { ) -> bool {
src.file_id.original_file(ctx.db()) == ctx.file_id() src.original_file(ctx.db()) == ctx.file_id() && !body.contains_node(src.syntax())
&& !body.contains_node(either_syntax(&src.value))
}
fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
match value {
Either::Left(pat) => pat.syntax(),
Either::Right(it) => it.syntax(),
}
} }
/// find where to put extracted function definition /// find where to put extracted function definition

View file

@ -1,4 +1,3 @@
use either::Either;
use hir::{PathResolution, Semantics}; use hir::{PathResolution, Semantics};
use ide_db::{ use ide_db::{
base_db::FileId, base_db::FileId,
@ -205,12 +204,14 @@ fn inline_usage(
return None; return None;
} }
// FIXME: Handle multiple local definitions let sources = local.sources(sema.db);
let bind_pat = match local.source(sema.db).value { let [source] = sources.as_slice() else {
Either::Left(ident) => ident, // Not applicable with locals with multiple definitions (i.e. or patterns)
_ => return None, return None;
}; };
let bind_pat = source.as_ident_pat()?;
let let_stmt = ast::LetStmt::cast(bind_pat.syntax().parent()?)?; let let_stmt = ast::LetStmt::cast(bind_pat.syntax().parent()?)?;
let UsageSearchResult { mut references } = Definition::Local(local).usages(sema).all(); let UsageSearchResult { mut references } = Definition::Local(local).usages(sema).all();

View file

@ -121,14 +121,8 @@ impl Definition {
Definition::Trait(it) => name_range(it, sema), Definition::Trait(it) => name_range(it, sema),
Definition::TraitAlias(it) => name_range(it, sema), Definition::TraitAlias(it) => name_range(it, sema),
Definition::TypeAlias(it) => name_range(it, sema), Definition::TypeAlias(it) => name_range(it, sema),
Definition::Local(local) => { // A local might be `self` or have multiple definitons like `let (a | a) = 2`, so it should be handled as a special case
let src = local.source(sema.db); Definition::Local(_) => return None,
let name = match &src.value {
Either::Left(bind_pat) => bind_pat.name()?,
Either::Right(_) => return None,
};
src.with_value(name.syntax()).original_file_range_opt(sema.db)
}
Definition::GenericParam(generic_param) => match generic_param { Definition::GenericParam(generic_param) => match generic_param {
hir::GenericParam::LifetimeParam(lifetime_param) => { hir::GenericParam::LifetimeParam(lifetime_param) => {
let src = lifetime_param.source(sema.db)?; let src = lifetime_param.source(sema.db)?;
@ -302,13 +296,7 @@ fn rename_reference(
source_change.insert_source_edit(file_id, edit); source_change.insert_source_edit(file_id, edit);
Ok(()) Ok(())
}; };
match def { insert_def_edit(def)?;
Definition::Local(l) => l
.associated_locals(sema.db)
.iter()
.try_for_each(|&local| insert_def_edit(Definition::Local(local))),
def => insert_def_edit(def),
}?;
Ok(source_change) Ok(source_change)
} }
@ -471,16 +459,16 @@ fn source_edit_from_def(
def: Definition, def: Definition,
new_name: &str, new_name: &str,
) -> Result<(FileId, TextEdit)> { ) -> Result<(FileId, TextEdit)> {
let FileRange { file_id, range } = def
.range_for_rename(sema)
.ok_or_else(|| format_err!("No identifier available to rename"))?;
let mut edit = TextEdit::builder(); let mut edit = TextEdit::builder();
if let Definition::Local(local) = def { if let Definition::Local(local) = def {
if let Either::Left(pat) = local.source(sema.db).value { let mut file_id = None;
for source in local.sources(sema.db) {
let source = source.source;
file_id = source.file_id.file_id();
if let Either::Left(pat) = source.value {
let name_range = pat.name().unwrap().syntax().text_range();
// special cases required for renaming fields/locals in Record patterns // special cases required for renaming fields/locals in Record patterns
if let Some(pat_field) = pat.syntax().parent().and_then(ast::RecordPatField::cast) { if let Some(pat_field) = pat.syntax().parent().and_then(ast::RecordPatField::cast) {
let name_range = pat.name().unwrap().syntax().text_range();
if let Some(name_ref) = pat_field.name_ref() { if let Some(name_ref) = pat_field.name_ref() {
if new_name == name_ref.text() && pat.at_token().is_none() { if new_name == name_ref.text() && pat.at_token().is_none() {
// Foo { field: ref mut local } -> Foo { ref mut field } // Foo { field: ref mut local } -> Foo { ref mut field }
@ -510,20 +498,25 @@ fn source_edit_from_def(
); );
edit.replace(name_range, new_name.to_string()); edit.replace(name_range, new_name.to_string());
} }
} else {
edit.replace(name_range, new_name.to_string());
} }
} }
} }
if edit.is_empty() { let Some(file_id) = file_id else { bail!("No file available to rename") };
return Ok((file_id, edit.finish()));
}
let FileRange { file_id, range } = def
.range_for_rename(sema)
.ok_or_else(|| format_err!("No identifier available to rename"))?;
let (range, new_name) = match def { let (range, new_name) = match def {
Definition::GenericParam(hir::GenericParam::LifetimeParam(_)) Definition::GenericParam(hir::GenericParam::LifetimeParam(_)) | Definition::Label(_) => (
| Definition::Label(_) => (
TextRange::new(range.start() + syntax::TextSize::from(1), range.end()), TextRange::new(range.start() + syntax::TextSize::from(1), range.end()),
new_name.strip_prefix('\'').unwrap_or(new_name).to_owned(), new_name.strip_prefix('\'').unwrap_or(new_name).to_owned(),
), ),
_ => (range, new_name.to_owned()), _ => (range, new_name.to_owned()),
}; };
edit.replace(range, new_name); edit.replace(range, new_name);
}
Ok((file_id, edit.finish())) Ok((file_id, edit.finish()))
} }

View file

@ -320,7 +320,7 @@ impl Definition {
scope: None, scope: None,
include_self_kw_refs: None, include_self_kw_refs: None,
local_repr: match self { local_repr: match self {
Definition::Local(local) => Some(local.representative(sema.db)), Definition::Local(local) => Some(local),
_ => None, _ => None,
}, },
search_self_mod: false, search_self_mod: false,
@ -646,7 +646,7 @@ impl<'a> FindUsages<'a> {
match NameRefClass::classify(self.sema, name_ref) { match NameRefClass::classify(self.sema, name_ref) {
Some(NameRefClass::Definition(def @ Definition::Local(local))) Some(NameRefClass::Definition(def @ Definition::Local(local)))
if matches!( if matches!(
self.local_repr, Some(repr) if repr == local.representative(self.sema.db) self.local_repr, Some(repr) if repr == local
) => ) =>
{ {
let FileRange { file_id, range } = self.sema.original_range(name_ref.syntax()); let FileRange { file_id, range } = self.sema.original_range(name_ref.syntax());
@ -707,7 +707,7 @@ impl<'a> FindUsages<'a> {
Definition::Field(_) if field == self.def => { Definition::Field(_) if field == self.def => {
ReferenceCategory::new(&field, name_ref) ReferenceCategory::new(&field, name_ref)
} }
Definition::Local(_) if matches!(self.local_repr, Some(repr) if repr == local.representative(self.sema.db)) => { Definition::Local(_) if matches!(self.local_repr, Some(repr) if repr == local) => {
ReferenceCategory::new(&Definition::Local(local), name_ref) ReferenceCategory::new(&Definition::Local(local), name_ref)
} }
_ => return false, _ => return false,
@ -755,7 +755,7 @@ impl<'a> FindUsages<'a> {
Some(NameClass::Definition(def @ Definition::Local(local))) if def != self.def => { Some(NameClass::Definition(def @ Definition::Local(local))) if def != self.def => {
if matches!( if matches!(
self.local_repr, self.local_repr,
Some(repr) if local.representative(self.sema.db) == repr Some(repr) if local == repr
) { ) {
let FileRange { file_id, range } = self.sema.original_range(name.syntax()); let FileRange { file_id, range } = self.sema.original_range(name.syntax());
let reference = FileReference { let reference = FileReference {

View file

@ -14,7 +14,7 @@ use syntax::{
SyntaxNode, SyntaxToken, TextRange, T, SyntaxNode, SyntaxToken, TextRange, T,
}; };
use crate::{references, NavigationTarget, TryToNav}; use crate::{navigation_target::ToNav, references, NavigationTarget, TryToNav};
#[derive(PartialEq, Eq, Hash)] #[derive(PartialEq, Eq, Hash)]
pub struct HighlightedRange { pub struct HighlightedRange {
@ -98,8 +98,22 @@ fn highlight_references(
category: access, category: access,
}); });
let mut res = FxHashSet::default(); let mut res = FxHashSet::default();
for &def in &defs {
let mut def_to_hl_range = |def| { match def {
Definition::Local(local) => {
let category = local.is_mut(sema.db).then_some(ReferenceCategory::Write);
local
.sources(sema.db)
.into_iter()
.map(|x| x.to_nav(sema.db))
.filter(|decl| decl.file_id == file_id)
.filter_map(|decl| decl.focus_range)
.map(|range| HighlightedRange { range, category })
.for_each(|x| {
res.insert(x);
});
}
def => {
let hl_range = match def { let hl_range = match def {
Definition::Module(module) => { Definition::Module(module) => {
Some(NavigationTarget::from_module_to_decl(sema.db, module)) Some(NavigationTarget::from_module_to_decl(sema.db, module))
@ -109,21 +123,14 @@ fn highlight_references(
.filter(|decl| decl.file_id == file_id) .filter(|decl| decl.file_id == file_id)
.and_then(|decl| decl.focus_range) .and_then(|decl| decl.focus_range)
.map(|range| { .map(|range| {
let category = let category = references::decl_mutability(&def, node, range)
references::decl_mutability(&def, node, range).then_some(ReferenceCategory::Write); .then_some(ReferenceCategory::Write);
HighlightedRange { range, category } HighlightedRange { range, category }
}); });
if let Some(hl_range) = hl_range { if let Some(hl_range) = hl_range {
res.insert(hl_range); res.insert(hl_range);
} }
}; }
for &def in &defs {
match def {
Definition::Local(local) => local
.associated_locals(sema.db)
.iter()
.for_each(|&local| def_to_hl_range(Definition::Local(local))),
def => def_to_hl_range(def),
} }
} }

View file

@ -635,8 +635,8 @@ fn local(db: &RootDatabase, it: hir::Local) -> Option<Markup> {
let ty = it.ty(db); let ty = it.ty(db);
let ty = ty.display_truncated(db, None); let ty = ty.display_truncated(db, None);
let is_mut = if it.is_mut(db) { "mut " } else { "" }; let is_mut = if it.is_mut(db) { "mut " } else { "" };
let desc = match it.source(db).value { let desc = match it.primary_source(db).into_ident_pat() {
Either::Left(ident) => { Some(ident) => {
let name = it.name(db); let name = it.name(db);
let let_kw = if ident let let_kw = if ident
.syntax() .syntax()
@ -649,7 +649,7 @@ fn local(db: &RootDatabase, it: hir::Local) -> Option<Markup> {
}; };
format!("{let_kw}{is_mut}{name}: {ty}") format!("{let_kw}{is_mut}{name}: {ty}")
} }
Either::Right(_) => format!("{is_mut}self: {ty}"), None => format!("{is_mut}self: {ty}"),
}; };
markup(None, desc, None) markup(None, desc, None)
} }

View file

@ -5,7 +5,7 @@ use std::fmt;
use either::Either; use either::Either;
use hir::{ use hir::{
symbols::FileSymbol, AssocItem, Documentation, FieldSource, HasAttrs, HasSource, HirDisplay, symbols::FileSymbol, AssocItem, Documentation, FieldSource, HasAttrs, HasSource, HirDisplay,
InFile, ModuleSource, Semantics, InFile, LocalSource, ModuleSource, Semantics,
}; };
use ide_db::{ use ide_db::{
base_db::{FileId, FileRange}, base_db::{FileId, FileRange},
@ -387,9 +387,11 @@ impl TryToNav for hir::GenericParam {
} }
} }
impl ToNav for hir::Local { impl ToNav for LocalSource {
fn to_nav(&self, db: &RootDatabase) -> NavigationTarget { fn to_nav(&self, db: &RootDatabase) -> NavigationTarget {
let InFile { file_id, value } = self.source(db); let InFile { file_id, value } = &self.source;
let file_id = *file_id;
let local = self.local;
let (node, name) = match &value { let (node, name) = match &value {
Either::Left(bind_pat) => (bind_pat.syntax(), bind_pat.name()), Either::Left(bind_pat) => (bind_pat.syntax(), bind_pat.name()),
Either::Right(it) => (it.syntax(), it.name()), Either::Right(it) => (it.syntax(), it.name()),
@ -398,10 +400,10 @@ impl ToNav for hir::Local {
let FileRange { file_id, range: full_range } = let FileRange { file_id, range: full_range } =
InFile::new(file_id, node).original_file_range(db); InFile::new(file_id, node).original_file_range(db);
let name = self.name(db).to_smol_str(); let name = local.name(db).to_smol_str();
let kind = if self.is_self(db) { let kind = if local.is_self(db) {
SymbolKind::SelfParam SymbolKind::SelfParam
} else if self.is_param(db) { } else if local.is_param(db) {
SymbolKind::ValueParam SymbolKind::ValueParam
} else { } else {
SymbolKind::Local SymbolKind::Local
@ -419,6 +421,12 @@ impl ToNav for hir::Local {
} }
} }
impl ToNav for hir::Local {
fn to_nav(&self, db: &RootDatabase) -> NavigationTarget {
self.primary_source(db).to_nav(db)
}
}
impl ToNav for hir::Label { impl ToNav for hir::Label {
fn to_nav(&self, db: &RootDatabase) -> NavigationTarget { fn to_nav(&self, db: &RootDatabase) -> NavigationTarget {
let InFile { file_id, value } = self.source(db); let InFile { file_id, value } = self.source(db);