1use std::collections::HashSet;
4use tinymist_derive::BindTyCtx;
5
6use super::{
7 ArgsTy, Sig, SigChecker, SigShape, SigSurfaceKind, SigTy, Ty, TyCtx, TyCtxMut, TypeBounds,
8 TypeInfo, TypeVar,
9};
10use super::{DynTypeBounds, ParamAttrs, ParamTy, SharedContext, prelude::*};
11use crate::syntax::{ArgClass, SyntaxContext, VarClass, classify_context, classify_context_outer};
12use crate::ty::{BuiltinTy, RecordTy};
13
14#[typst_macros::time(span = node.span())]
17pub(crate) fn post_type_check(
18 ctx: Arc<SharedContext>,
19 ti: &TypeInfo,
20 node: LinkedNode,
21) -> Option<Ty> {
22 let mut checker = PostTypeChecker::new(ctx, ti);
23 let res = checker.check(&node);
24 checker.simplify(&res?)
25}
26
27#[derive(Default)]
28struct SignatureReceiver {
29 lbs_dedup: HashSet<Ty>,
30 ubs_dedup: HashSet<Ty>,
31 bounds: TypeBounds,
32}
33
34impl SignatureReceiver {
35 fn insert(&mut self, ty: Ty, pol: bool) {
36 crate::log_debug_ct!("post check receive: {ty:?}");
37 if !pol {
38 if self.lbs_dedup.insert(ty.clone()) {
39 self.bounds.lbs.push(ty);
40 }
41 } else if self.ubs_dedup.insert(ty.clone()) {
42 self.bounds.ubs.push(ty);
43 }
44 }
45
46 fn finalize(self) -> Ty {
47 Ty::Let(self.bounds.into())
48 }
49}
50
51fn check_signature<'a>(
52 receiver: &'a mut SignatureReceiver,
53 arg: &'a ArgClass,
54) -> impl FnMut(&mut PostTypeChecker, Sig, &[Interned<ArgsTy>], bool) -> Option<()> + 'a {
55 move |worker, sig, args, pol| {
56 let (sig, _is_partialize) = match sig {
57 Sig::Partialize(sig) => (*sig, true),
58 sig => (sig, false),
59 };
60
61 let SigShape { sig: sig_ins, .. } = sig.shape(worker)?;
62
63 match &arg {
64 ArgClass::Named(n) => {
65 let ident = n.cast::<ast::Ident>()?;
66 let ty = sig_ins.named(&ident.into())?;
67 receiver.insert(ty.clone(), !pol);
68
69 Some(())
70 }
71 ArgClass::Positional {
72 spreads: _,
74 positional,
75 is_spread,
76 } => {
77 if *is_spread {
78 return None;
79 }
80
81 let bound_pos = args
83 .iter()
84 .map(|args| args.positional_params().len())
85 .sum::<usize>();
86 if let Some(nth) = sig_ins.pos_or_rest(bound_pos + positional) {
87 receiver.insert(nth, !pol);
88 }
89
90 for (name, ty) in sig_ins.named_params() {
92 let field = ParamTy::new(ty.clone(), name.clone(), ParamAttrs::named());
93 receiver.insert(Ty::Param(field), !pol);
94 }
95
96 Some(())
97 }
98 }
99 }
100}
101
102pub(crate) struct PostTypeChecker<'a> {
103 ctx: Arc<SharedContext>,
104 pub info: &'a TypeInfo,
105 checked: HashMap<Span, Option<Ty>>,
106 locals: TypeInfo,
107}
108
109impl TyCtx for PostTypeChecker<'_> {
110 fn global_bounds(&self, var: &Interned<TypeVar>, pol: bool) -> Option<DynTypeBounds> {
111 self.info.global_bounds(var, pol)
112 }
113
114 fn local_bind_of(&self, var: &Interned<TypeVar>) -> Option<Ty> {
115 self.locals.local_bind_of(var)
116 }
117}
118
119impl TyCtxMut for PostTypeChecker<'_> {
120 type Snap = <TypeInfo as TyCtxMut>::Snap;
121
122 fn start_scope(&mut self) -> Self::Snap {
123 self.locals.start_scope()
124 }
125
126 fn end_scope(&mut self, snap: Self::Snap) {
127 self.locals.end_scope(snap)
128 }
129
130 fn bind_local(&mut self, var: &Interned<TypeVar>, ty: Ty) {
131 self.locals.bind_local(var, ty);
132 }
133
134 fn type_of_func(&mut self, func: &Func) -> Option<Interned<SigTy>> {
135 Some(self.ctx.type_of_func(func.clone()).type_sig())
136 }
137
138 fn type_of_value(&mut self, val: &Value) -> Ty {
139 self.ctx.type_of_value(val)
140 }
141
142 fn check_module_item(&mut self, _module: TypstFileId, _key: &StrRef) -> Option<Ty> {
143 None
144 }
145}
146
147impl<'a> PostTypeChecker<'a> {
148 pub fn new(ctx: Arc<SharedContext>, info: &'a TypeInfo) -> Self {
149 Self {
150 ctx,
151 info,
152 checked: HashMap::new(),
153 locals: TypeInfo::default(),
154 }
155 }
156
157 fn check(&mut self, node: &LinkedNode) -> Option<Ty> {
158 let span = node.span();
159 if let Some(ty) = self.checked.get(&span) {
160 return ty.clone();
161 }
162 self.checked.insert(span, None);
164
165 let ty = self.check_(node);
166 self.checked.insert(span, ty.clone());
167 ty
168 }
169
170 fn simplify(&mut self, ty: &Ty) -> Option<Ty> {
171 Some(self.info.simplify(ty.clone(), false))
172 }
173
174 fn check_(&mut self, node: &LinkedNode) -> Option<Ty> {
175 let context = node.parent()?;
176 crate::log_debug_ct!("post check: {:?}::{:?}", context.kind(), node.kind());
177
178 let context_ty = self.check_context(context, node);
179 let self_ty = if !matches!(node.kind(), SyntaxKind::Label | SyntaxKind::Ref) {
180 self.info.type_of_span(node.span())
181 } else {
182 None
183 };
184
185 let can_penetrate_context = !(matches!(
186 node.kind(),
187 SyntaxKind::Hash | SyntaxKind::ContentBlock | SyntaxKind::CodeBlock
188 ) || matches!(context.kind(), SyntaxKind::FieldAccess) && {
189 let field_access = context.cast::<ast::FieldAccess>()?;
190 field_access.field().span() == node.span()
191 });
192
193 let contextual_self_ty = can_penetrate_context
194 .then(|| self.check_cursor(classify_context(node.clone(), None), context_ty));
195 crate::log_debug_ct!(
196 "post check(res): {:?}::{:?} -> {self_ty:?}, {contextual_self_ty:?}",
197 context.kind(),
198 node.kind(),
199 );
200
201 Ty::union(self_ty, contextual_self_ty.flatten())
202 }
203
204 fn check_or(&mut self, node: &LinkedNode, ty: Option<Ty>) -> Option<Ty> {
205 Ty::union(self.check(node), ty)
206 }
207
208 fn check_cursor(
209 &mut self,
210 cursor: Option<SyntaxContext>,
211 context_ty: Option<Ty>,
212 ) -> Option<Ty> {
213 let Some(cursor) = cursor else {
214 return context_ty;
215 };
216 crate::log_debug_ct!("post check target: {cursor:?}");
217
218 match &cursor {
219 SyntaxContext::Arg {
220 callee,
221 args: _,
222 target,
223 is_set,
224 } => {
225 let callee_ty = self.check_or(callee, context_ty)?;
226 crate::log_debug_ct!(
227 "post check call target: ({callee_ty:?})::{target:?} is_set: {is_set}"
228 );
229
230 let sig = self.ctx.sig_of_type_or_dyn(self.info, callee_ty, callee)?;
231 crate::log_debug_ct!("post check call sig: {target:?} {sig:?}");
232 let mut resp = SignatureReceiver::default();
233
234 match target {
235 ArgClass::Named(n) => {
236 let ident = n.cast::<ast::Ident>()?.into();
237 let ty = sig.primary().get_named(&ident)?;
238 resp.insert(ty.ty.clone(), false);
240 }
241 ArgClass::Positional {
242 spreads: _,
244 positional,
245 is_spread,
246 } => {
247 if *is_spread {
248 return None;
249 }
250
251 let shift = sig.param_shift();
253 let nth = sig
254 .primary()
255 .get_pos(shift + positional)
256 .or_else(|| sig.primary().rest());
257 if let Some(nth) = nth {
258 resp.insert(Ty::Param(nth.clone()), false);
259 }
260
261 for field in sig.primary().named() {
263 if *is_set && !field.attrs.settable {
264 continue;
265 }
266
267 resp.insert(Ty::Param(field.clone()), false);
268 }
269 }
270 }
271
272 crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
273 Some(resp.finalize())
274 }
275 SyntaxContext::Element { container, target } => {
276 let container_expr = match container.kind() {
284 SyntaxKind::Array | SyntaxKind::Dict => container
285 .parent()
286 .cloned()
287 .filter(|p| p.kind() == SyntaxKind::Parenthesized)
288 .unwrap_or_else(|| container.clone()),
289 _ => container.clone(),
290 };
291 let container_ty = self.check_or(&container_expr, context_ty)?;
292 crate::log_debug_ct!("post check element target: ({container_ty:?})::{target:?}");
293
294 let mut resp = SignatureReceiver::default();
295
296 self.check_element_of(
297 &container_ty,
298 false,
299 container,
300 &mut check_signature(&mut resp, target),
301 );
302
303 crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
304 Some(resp.finalize())
305 }
306 SyntaxContext::Paren {
307 container,
308 is_before,
309 } => {
310 let container_ty = self.check_or(container, context_ty)?;
311 crate::log_debug_ct!("post check paren target: {container_ty:?}::{is_before:?}");
312
313 let mut resp = SignatureReceiver::default();
314 resp.bounds.lbs.push(container_ty.clone());
317
318 let target = ArgClass::first_positional();
319 self.check_element_of(
320 &container_ty,
321 false,
322 container,
323 &mut check_signature(&mut resp, &target),
324 );
325
326 crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
327 Some(resp.finalize())
328 }
329 SyntaxContext::ImportPath(..) | SyntaxContext::IncludePath(..) => {
330 Some(Ty::Builtin(BuiltinTy::Path(crate::ty::PathKind::Source {
331 allow_package: true,
332 })))
333 }
334 SyntaxContext::VarAccess(VarClass::Ident(node))
335 | SyntaxContext::VarAccess(VarClass::FieldAccess(node))
336 | SyntaxContext::VarAccess(VarClass::DotAccess(node))
337 | SyntaxContext::Label { node, .. }
338 | SyntaxContext::Ref { node, .. }
339 | SyntaxContext::At { node, .. }
340 | SyntaxContext::Normal(node) => {
341 let label_or_ref_ty = match cursor {
342 SyntaxContext::Label { is_error: true, .. } => {
343 Some(Ty::Builtin(BuiltinTy::Label))
344 }
345 SyntaxContext::Ref {
346 suffix_colon: true, ..
347 } => Some(Ty::Builtin(BuiltinTy::RefLabel)),
348 _ => None,
349 };
350 let ty = self.check_or(node, context_ty);
351 crate::log_debug_ct!("post check target normal: {ty:?} {label_or_ref_ty:?}");
352 ty.or(label_or_ref_ty)
353 }
354 }
355 }
356
357 fn check_context(&mut self, context: &LinkedNode, node: &LinkedNode) -> Option<Ty> {
359 match context.kind() {
360 SyntaxKind::LetBinding => {
361 let let_binding = context.cast::<ast::LetBinding>()?;
362 let let_init = let_binding.init()?;
363 let let_init_node = context.find(let_init.span())?;
364 let_init_node.find(node.span())?;
365
366 match let_binding.kind() {
367 ast::LetBindingKind::Closure(_c) => None,
368 ast::LetBindingKind::Normal(pattern) => self.check_let_pattern(pattern),
369 }
370 }
371 SyntaxKind::Args => self.check_cursor(
372 classify_context_outer(context.clone(), node.clone()),
374 None,
375 ),
376 SyntaxKind::Named => self.check_cursor(classify_context(context.clone(), None), None),
378 _ => None,
379 }
380 }
381
382 fn check_let_pattern(&mut self, pattern: ast::Pattern) -> Option<Ty> {
384 match pattern {
385 ast::Pattern::Placeholder(_) => None,
386 ast::Pattern::Normal(ast::Expr::Ident(ident)) => self.info.type_of_span(ident.span()),
387 ast::Pattern::Normal(..) => None,
388 ast::Pattern::Parenthesized(expr) => {
389 self.check_let_pattern(expr.expr().to_untyped().cast()?)
390 }
391 ast::Pattern::Destructuring(expr) => self.check_let_destruct(expr),
392 }
393 }
394
395 fn check_let_destruct(&mut self, destructuring: ast::Destructuring) -> Option<Ty> {
397 let mut pos = vec![];
398 let mut named = vec![];
399 for item in destructuring.items() {
400 match item {
401 ast::DestructuringItem::Pattern(pat) => {
402 pos.push(self.check_let_pattern(pat).unwrap_or(Ty::Any));
403 }
404 ast::DestructuringItem::Named(named_item) => {
405 let key = Interned::new_str(named_item.name().get().as_str());
406 let ty = self
407 .check_let_pattern(named_item.pattern())
408 .unwrap_or(Ty::Any);
409 named.push((key, ty));
410 }
411 ast::DestructuringItem::Spread(..) => {}
414 }
415 }
416
417 let tuple_ty = (!pos.is_empty()).then(|| Ty::Tuple(pos.into()));
418 let dict_ty = (!named.is_empty()).then(|| Ty::Dict(RecordTy::new(named)));
419 Ty::union(tuple_ty, dict_ty)
420 }
421
422 fn check_element_of<T>(&mut self, ty: &Ty, pol: bool, context: &LinkedNode, checker: &mut T)
423 where
424 T: PostSigChecker,
425 {
426 let mut checker = PostSigCheckWorker(self, checker);
427 ty.sig_surface(pol, sig_context_of(context), &mut checker)
428 }
429}
430
431trait PostSigChecker {
432 fn check(
433 &mut self,
434 checker: &mut PostTypeChecker,
435 sig: Sig,
436 args: &[Interned<ArgsTy>],
437 pol: bool,
438 ) -> Option<()>;
439}
440
441impl<T> PostSigChecker for T
442where
443 T: FnMut(&mut PostTypeChecker, Sig, &[Interned<ArgsTy>], bool) -> Option<()>,
444{
445 fn check(
446 &mut self,
447 checker: &mut PostTypeChecker,
448 sig: Sig,
449 args: &[Interned<ArgsTy>],
450 pol: bool,
451 ) -> Option<()> {
452 self(checker, sig, args, pol)
453 }
454}
455
456#[derive(BindTyCtx)]
457#[bind(0)]
458struct PostSigCheckWorker<'x, 'a, T>(&'x mut PostTypeChecker<'a>, &'x mut T);
459
460impl<T: PostSigChecker> SigChecker for PostSigCheckWorker<'_, '_, T> {
461 fn check(
462 &mut self,
463 sig: Sig,
464 args: &mut crate::analysis::SigCheckContext,
465 pol: bool,
466 ) -> Option<()> {
467 self.1.check(self.0, sig, &args.args, pol)
468 }
469}
470
471fn sig_context_of(context: &LinkedNode) -> SigSurfaceKind {
472 match context.kind() {
473 SyntaxKind::Parenthesized => SigSurfaceKind::ArrayOrDict,
474 SyntaxKind::Array => {
475 let arr = context.cast::<ast::Array>();
476 if arr.is_some_and(|arr| arr.items().next().is_some()) {
477 SigSurfaceKind::Array
478 } else {
479 SigSurfaceKind::ArrayOrDict
480 }
481 }
482 SyntaxKind::Dict => SigSurfaceKind::Dict,
483 _ => SigSurfaceKind::Array,
484 }
485}