1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
use syn::{
    ItemTrait, TraitItem, PathArguments, TypeParamBound,
    TraitBoundModifier, punctuated::Punctuated, token::Add,
};
use quote::{ToTokens, quote};

use crate::{
    attr::{MsgAttr, ERROR_TYPE},
    err::{ErrorSink, CompileErrors},
    generate::{self, MsgType},
    method::{Method, trait_methods},
    validate
};

/// Names of the message attributes an interface trait may put on its methods.
///
/// Consulted by [`is_valid_attr`].
pub const SUPPORTED_ATTRS: [&str; 3] = [
    MsgAttr::EXECUTE,
    MsgAttr::QUERY,
    MsgAttr::INIT,
];

/// Expands an interface trait into its message type definitions.
///
/// The trait is first parsed and validated into an `Interface`. Then the
/// following are generated, in this order:
/// - the init message, only when the trait declares an init method;
/// - the execute messages;
/// - the query messages.
///
/// Every step pushes its problems into one shared [`ErrorSink`]. The sink is
/// checked once, after all generation has run.
///
/// # Errors
///
/// Returns [`CompileErrors`] when parsing or any generator reported an error
/// to the sink.
pub fn derive(r#trait: ItemTrait) -> Result<proc_macro2::TokenStream, CompileErrors> {
    let mut sink = ErrorSink::default();
    let interface = Interface::parse(&mut sink, &r#trait);

    // An interface does not have to declare an init method; emit nothing in that case.
    let init_msg = interface
        .init
        .map(|init| generate::init_msg(&mut sink, &init).to_token_stream())
        .unwrap_or_default();

    let execute_msg = generate::messages(&mut sink, MsgType::Execute, &interface.execute);
    let query_msg = generate::messages(&mut sink, MsgType::Query, &interface.query);

    sink.check()?;

    Ok(quote! {
        #init_msg
        #execute_msg
        #query_msg
    })
}

/// Reports whether `attr` is one of the attributes listed in
/// [`SUPPORTED_ATTRS`], i.e. one an interface trait is allowed to use.
#[inline]
pub fn is_valid_attr(attr: MsgAttr) -> bool {
    let name = attr.as_str();
    SUPPORTED_ATTRS.iter().any(|&supported| supported == name)
}

/// Methods of an interface trait, grouped by the kind of message they produce.
struct Interface<'a> {
    /// Optional because an interface might not want to have an init method.
    init: Option<Method<'a>>,
    /// Methods annotated as execute messages.
    execute: Vec<Method<'a>>,
    /// Methods annotated as query messages.
    query: Vec<Method<'a>>,
}

impl<'a> Interface<'a> {
    /// Validates an interface trait and groups its methods by message kind.
    ///
    /// Problems are pushed to `sink` instead of being returned. This function
    /// therefore always produces an `Interface`, which may be missing rejected
    /// methods. It reports:
    /// - generics on the trait itself;
    /// - a missing associated `ERROR_TYPE` type;
    /// - an `ERROR_TYPE` type that has generics, or does not have exactly one
    ///   `Display` bound;
    /// - more than one init method (only the first one is kept);
    /// - an init method that declares an entry point;
    /// - methods annotated with an attribute interfaces do not support.
    fn parse(sink: &mut ErrorSink, r#trait: &'a ItemTrait) -> Self {
        let trait_ident = &r#trait.ident;
        let mut init: Option<Method> = None;
        let mut execute: Vec<Method> = vec![];
        let mut query: Vec<Method> = vec![];

        // We forbid generic traits because they would complicate the error type on contracts.
        if validate::has_generics(&r#trait.generics) {
            sink.push_spanned(r#trait, "Interface traits cannot have any generics.");
        }

        // `Ident` compares directly against a string, so no `String` is allocated per item.
        let err_ty = r#trait.items.iter().find_map(|item| match item {
            TraitItem::Type(type_def) if type_def.ident == ERROR_TYPE => Some(type_def),
            _ => None,
        });

        if let Some(err_ty) = err_ty {
            if !validate_err_bound(&err_ty.bounds) {
                sink.push_spanned(
                    err_ty,
                    format!("{} type must have a single \"std::fmt::Display\" bound.", ERROR_TYPE),
                );
            }

            if validate::has_generics(&err_ty.generics) {
                sink.push_spanned(
                    &err_ty.generics,
                    format!("{} type cannot have any generics.", ERROR_TYPE),
                );
            }
        } else {
            sink.push_spanned(
                trait_ident,
                format!(
                    "Missing \"type {}: std::fmt::Display;\" trait type declaration.",
                    ERROR_TYPE
                ),
            );
        }

        for method in trait_methods(sink, r#trait) {
            let ty = method.ty;
            match ty {
                // Only the first init method is kept; any later one is reported and dropped.
                MsgAttr::Init { .. } if init.is_some() => {
                    sink.duplicate_annotation(trait_ident, ty)
                }
                MsgAttr::Init { entry } => {
                    if entry.is_some() {
                        sink.push_spanned(&method.sig, "Interfaces cannot have entry points.");
                    }

                    init = Some(Method::Interface(method));
                }
                MsgAttr::Execute => execute.push(Method::Interface(method)),
                MsgAttr::Query => query.push(Method::Interface(method)),
                unsupported => sink.unsupported_interface_attr(&method.sig.ident, unsupported),
            }
        }

        Self {
            init,
            execute,
            query,
        }
    }
}

/// Returns `true` when `bounds` is exactly one plain trait bound whose last
/// path segment is `Display` with no generic arguments.
///
/// Accepted examples: `Display`, `fmt::Display`, `std::fmt::Display`.
///
/// Only the final path segment is inspected, because a procedural macro
/// cannot resolve paths. The following are rejected:
/// - lifetime bounds;
/// - `?Display`;
/// - `for<'a> Display`;
/// - `Display<..>` and `Display(..)`;
/// - any list with zero or several bounds.
fn validate_err_bound(bounds: &Punctuated<TypeParamBound, Add>) -> bool {
    if bounds.len() != 1 {
        return false;
    }

    let Some(TypeParamBound::Trait(bound)) = bounds.first() else {
        return false;
    };

    if !matches!(bound.modifier, TraitBoundModifier::None) || bound.lifetimes.is_some() {
        return false;
    }

    let Some(segment) = bound.path.segments.last() else {
        return false;
    };

    // `Ident == &str` avoids allocating a `String`; `matches!` avoids relying on
    // syn's optional `PartialEq` impls (the `extra-traits` feature).
    segment.ident == "Display" && matches!(segment.arguments, PathArguments::None)
}