1 /**
2    Support the automatic implementation of test doubles via programmable mocks.
3  */
4 module unit_threaded.mock;
5 
6 import unit_threaded.from;
7 
/// Names any symbol or type, so that e.g. `__traits` results can be aliased.
template Identity(alias T) {
    alias Identity = T;
}

/// True when `__traits(getMember, T, member)` does not compile
/// (used to skip members that cannot be accessed from here).
private template isPrivate(T, string member) {
    enum isPrivate = !__traits(compiles, __traits(getMember, T, member));
}
10 
11 
/**
   Generates the D source that is mixed into a class deriving from `T` in
   order to override the virtual methods of `T`.

   For overload `i` of every member function `name`, the generated code does
   the following. It only applies when `isPrivate` is false for the member,
   the overload is virtual, and it is neither `const` nor `immutable`.
   - It declares `name_i_parameters` / `name_i_returnType` aliases.
   - For non-void overloads, it declares a `name_i_returnValues` array of
     queued return values.
   - It overrides the overload so that a call appends `name` to
     `calledFuncs` and the stringified argument tuple to `calledValues`.
     For `nothrow` overloads this is wrapped in `try`/`catch(Exception)`.
     The override then returns the next queued value, or `.init` once the
     queue is empty.

   `const`/`immutable` overloads are skipped because the override has to
   record the call in the mock's state. Attributes, return type and arity
   are taken from each individual overload, not from the first overload of
   the set.

   Returns: the generated code; `null` when not evaluated at compile time.
 */
string implMixinStr(T)() {
    import std.array: join;
    import std.format : format;
    import std.traits: functionAttributes, FunctionAttribute, Parameters, ReturnType, arity;
    import std.conv: text;

    if(!__ctfe) return null;

    string[] lines;

    // Source text that names overload `i` of `memberName` in the generated code
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`
            .format(memberName, i);
    }

    foreach(memberName; __traits(allMembers, T)) {

        static if(!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if(__traits(isVirtualMethod, member)) {
                foreach(i, _; __traits(getOverloads, T, memberName)) {

                    // The specific overload being generated. `member` only
                    // designates the first overload of the set.
                    alias ov = Identity!(__traits(getOverloads, T, memberName)[i]);
                    enum attrs = functionAttributes!ov;

                    static if(__traits(isVirtualMethod, ov) &&
                              !(attrs & FunctionAttribute.const_) &&
                              !(attrs & FunctionAttribute.immutable_)) {

                        enum overloadName = text(memberName, "_", i);

                        enum overloadString = getOverload(memberName, i);
                        lines ~= "private alias %s_parameters = Parameters!(%s);".format(overloadName, overloadString);
                        lines ~= "private alias %s_returnType = ReturnType!(%s);".format(overloadName, overloadString);

                        // nothrow overrides must swallow exceptions thrown while recording the call
                        enum isNothrow = (attrs & FunctionAttribute.nothrow_) != 0;
                        enum tryIndent = isNothrow ? "    " : "";

                        static if(is(ReturnType!ov == void))
                            enum returnDefault = "";
                        else {
                            enum varName = overloadName ~ `_returnValues`;
                            lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                            lines ~= "";
                            enum returnDefault = [`    if(` ~ varName ~ `.length > 0) {`,
                                                  `        auto ret = ` ~ varName ~ `[0];`,
                                                  `        ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                                  `        return ret;`,
                                                  `    } else`,
                                                  `        return %s_returnType.init;`.format(overloadName)];
                        }

                        lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~
                            typeAndArgsParens!(Parameters!ov)(overloadName) ~ " " ~
                            functionAttributesString!ov ~ ` {`;

                        static if(isNothrow)
                            lines ~= "try {";

                        lines ~= tryIndent ~ `    calledFuncs ~= "` ~ memberName ~ `";`;
                        lines ~= tryIndent ~ `    calledValues ~= tuple` ~ argNamesParens(arity!ov) ~ `.to!string;`;

                        static if(isNothrow)
                            lines ~= "    } catch(Exception) {}";

                        lines ~= returnDefault;

                        lines ~= `}`;
                        lines ~= "";
                    }
                }
            }
        }
    }

    return lines.join("\n");
}
90 
/// Parenthesised argument list for `N` arguments, e.g. `(arg0, arg1)` for `N == 2`.
/// Returns `null` when not evaluated at compile time.
private string argNamesParens(int N) @safe pure {
    if(__ctfe)
        return "(" ~ argNames(N) ~ ")";
    return null;
}
95 
/// Comma-separated argument names `arg0, arg1, ..., arg{N-1}`.
/// Returns `null` when not evaluated at compile time.
private string argNames(int N) @safe pure {
    import std.conv: to;

    if(!__ctfe) return null;

    string result;
    foreach(idx; 0 .. N) {
        if(idx > 0)
            result ~= ", ";
        result ~= "arg" ~ idx.to!string;
    }
    return result;
}
104 
/// Parenthesised parameter list with one entry per type in `T`, written as
/// `(<prefix>_parameters[0] arg0, <prefix>_parameters[1] arg1, ...)`.
/// Only the number of types in `T` matters. Returns `null` when not
/// evaluated at compile time.
private string typeAndArgsParens(T...)(string prefix) {
    import std.array;
    import std.format : format;

    if(!__ctfe) return null;

    auto app = appender!string;
    app.put("(");
    foreach(idx; 0 .. T.length) {
        if(idx != 0)
            app.put(", ");
        app.put(format("%s_parameters[%s] arg%s", prefix, idx, idx));
    }
    app.put(")");
    return app.data;
}
118 
/// Space-separated attributes of `F` to repeat on a generated override, in the
/// fixed order `pure nothrow @trusted @safe @nogc @system shared`.
/// Returns `null` when not evaluated at compile time.
private string functionAttributesString(alias F)() {
    import std.traits: functionAttributes, FunctionAttribute;
    import std.array: join;

    if(!__ctfe) return null;

    alias FA = FunctionAttribute;
    const flags = functionAttributes!F;

    // Masks and their spellings, in output order. const and immutable are
    // deliberately absent: the mock's overrides need to alter its state.
    immutable FA[7] masks = [FA.pure_, FA.nothrow_, FA.trusted, FA.safe,
                             FA.nogc, FA.system, FA.shared_];
    immutable string[7] names = ["pure", "nothrow", "@trusted", "@safe",
                                 "@nogc", "@system", "shared"];

    string[] present;
    foreach(idx, mask; masks) {
        if(flags & mask)
            present ~= names[idx];
    }
    return present.join(" ");
}
143 
/**
   Call-recording state and verification logic shared by the mocks in this
   module. The mock appends to `calledFuncs`/`calledValues` when it is
   called. Tests register expectations with `expect`/`expectCalled` and
   check them with `verify`.
 */
mixin template MockImplCommon() {
    bool _verified;          // whether `verify` has already run
    string[] expectedFuncs;  // names of the expected calls, in order
    string[] calledFuncs;    // names of the calls actually made, in order
    string[] expectedValues; // stringified expected argument tuples ("" = any arguments)
    string[] calledValues;   // stringified actual argument tuples, parallel to calledFuncs

    /**
       Registers an expected call to `funcName`. If `values` are given, the
       call's arguments must match them too. The comparison uses the string
       form of a `tuple` of the values. Without `values`, any arguments match.
     */
    void expect(string funcName, V...)(auto ref V values) {
        import std.conv: to;
        import std.typecons: tuple;

        expectedFuncs ~= funcName;
        static if(V.length > 0)
            expectedValues ~= tuple(values).to!string;
        else
            expectedValues ~= "";
    }

    /**
       Registers an expected call as `expect` does and verifies immediately.
       Afterwards the mock is marked as not verified, so it can be verified
       again.
     */
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    /**
       Checks that the recorded calls match the expected ones, in order.
       Calls made beyond the expected ones are not checked.

       Params:
         file = source file reported on failure
         line = source line reported on failure

       Throws: `UnitTestException` on an argument mismatch. Other failures
       (already verified, missing or different call) are reported through
       `unit_threaded.should.fail`, which is presumably throwing as well.
     */
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range: repeat, take, join;
        import std.conv: to;
        import unit_threaded.should: fail, UnitTestException;

        if(_verified)
            fail("Mock already verified", file, line);

        // Set before the checks, so it stays set even if a check throws
        _verified = true;

        foreach(i, expectedFunc; expectedFuncs) {

            if(i >= calledFuncs.length)
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFunc ~ " did not happen", file, line);

            if(expectedFunc != calledFuncs[i])
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFunc ~ " but got " ~ calledFuncs[i] ~
                     " instead",
                     file, line);

            // An empty expected value means "any arguments"
            if(expectedValues[i] != calledValues[i] && expectedValues[i] != "")
                throw new UnitTestException([expectedFunc ~ " was called with unexpected " ~ calledValues[i],
                                             " ".repeat.take(expectedFunc.length + 4).join ~
                                             "instead of the expected " ~ expectedValues[i]],
                                            file, line);
        }
    }
}
196 
/// True if the value `T` has type `string`.
private template isString(alias T) {
    enum isString = is(typeof(T) == string);
}
198 
/**
   A mock object that conforms to an interface/class `T`.

   The overridden methods record every call. Non-void methods return the
   values queued with `returnValue` one by one, then `.init`. Unless the
   expectations were already verified, they are verified when the mock is
   destroyed.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    /// Generated implementation of `T`; holds the recorded calls and expectations.
    class MockAbstract: T {
        import std.conv: to;
        import std.traits: Parameters, ReturnType;
        import std.typecons: tuple;

        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    /// Creates the implementation. The `int` parameter is a dummy that exists
    /// only because structs cannot declare a default constructor.
    this(int/* force constructor*/) {
        _impl = new MockAbstract;
    }

    /// Verifies the expectations unless that has already been done.
    /// A default-initialised `Mock` (with no implementation) is skipped
    /// instead of dereferencing `null`.
    ~this() pure @safe {
        if(_impl !is null && !_verified) verify;
    }

    /// Queues return values for the first overload of `funcName`.
    void returnValue(string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. `i` is the overload
       index, e.g.:
       ---------
       interface Interface { int foo(int); string foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo")(42);    // int overload
       m.returnValue!(1, "foo")("bar"); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        import std.conv: text;
        // Array declared by the code generated by implMixinStr
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach(v; values)
            mixin(varName ~ ` ~=  v;`);
    }

    // Compile-time check that `funcName` names a virtual method of `T`
    private static void assertFunctionIsVirtual(string funcName)() {
        alias member = Identity!(__traits(getMember, T, funcName));

        static assert(__traits(isVirtualMethod, member),
                      "Cannot use returnValue on '" ~ funcName ~ "'");
    }
}
258 
/// Builds newline-terminated `import` statements: one for `module_`, then one
/// for each of `Modules`, in order. Returns `null` when not evaluated at
/// compile time.
private string importsString(string module_, string[] Modules...) {
    if(!__ctfe) return null;

    string result;
    foreach(name; [module_] ~ Modules)
        result ~= "import " ~ name ~ ";\n";
    return result;
}
268 
/// Helper function for creating a `Mock!T`. It supplies the dummy `int` that
/// `Mock`'s constructor requires.
auto mock(T)() {
    return Mock!T(int.init);
}
273 
///
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int doubledFoo(Foo obj) {
        immutable result = obj.foo(5, "foobar");
        return result * 2;
    }

    auto mocked = mock!Foo;
    mocked.expect!"foo";
    doubledFoo(mocked);
}
290 
291 
///
@("mock interface positive test with params")
@safe pure unittest {
    import unit_threaded.asserts;

    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int doubledFoo(Foo obj) {
        immutable result = obj.foo(5, "foobar");
        return result * 2;
    }

    auto mocked = mock!Foo;
    mocked.expect!"foo"(5, "foobar");
    doubledFoo(mocked);
}
310 
311 
///
@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int doubledFoo(Foo obj) {
        immutable result = obj.foo(5, "foobar");
        return result * 2;
    }

    auto mocked = mock!Foo;
    doubledFoo(mocked);
    mocked.expectCalled!"foo"(5, "foobar");
}
328 
///
@("interface return value")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubledTimes3(Foo obj) {
        return 2 * obj.timesN(3);
    }

    auto mocked = mock!Foo;
    mocked.returnValue!"timesN"(42);
    assert(doubledTimes3(mocked) == 84);
}
346 
///
@("interface return values")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubledTimes3(Foo obj) {
        return 2 * obj.timesN(3);
    }

    auto mocked = mock!Foo;
    mocked.returnValue!"timesN"(42, 12);
    // queued values first, then int.init once they run out
    foreach(expected; [84, 24, 0])
        assert(doubledTimes3(mocked) == expected);
}
365 
/**
   Compile-time description, for `mockStruct`, of the values that the
   function named `function_` returns, one per call. Every element of `T`
   must be a value. `values` returns them as an array whose element type is
   the type of the first value.
 */
struct ReturnValues(string function_, T...) if(from!"std.meta".allSatisfy!(isValue, T)) {
    alias funcName = function_;
    alias Values = T;

    static auto values() {
        typeof(T[0])[] result = [T];
        return result;
    }
}
378 
/// True if `T` matches `ReturnValues!U` for some arguments `U`.
template isReturnValue(alias T) {
    enum isReturnValue = is(T : ReturnValues!U, U...);
}

/// True if `typeof(T)` is valid, i.e. `T` is a value rather than a type.
template isValue(alias T) {
    enum isValue = is(typeof(T));
}
381 
382 
/**
   Creates a mock struct on which any member function can be called.

   Every call is recorded for later verification. If `returns` is
   non-empty, each call (whatever its name) returns the next of those
   values in order, and the `.init` of their type once they run out. All
   values are stored as the type of the first one, so every function called
   on the mock returns that same type. With no `returns`, calls return
   `void`.
 */
auto mockStruct(T...)(auto ref T returns) {

    struct Mock {

        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            static if(T.length > 0) {
                alias FirstType = typeof(returns[0]);
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            auto opDispatch(string funcName, V...)(auto ref V values) {

                import std.conv: to;
                import std.typecons: tuple;

                calledFuncs ~= funcName;
                calledValues ~= tuple(values).to!string;

                static if(T.length > 0) {
                    if(_returnValues.length > 0) {
                        auto next = _returnValues[0];
                        _returnValues = _returnValues[1..$];
                        return next;
                    }
                    return typeof(_returnValues[0]).init;
                }
            }
        }
    }

    Mock result;
    result._impl = new Mock.MockImpl;
    static if(T.length > 0) {
        foreach(value; returns)
            result._impl._returnValues ~= value;
    }

    return result;
}
433 
/**
   Version of `mockStruct` that accepts a compile-time mapping of function
   names to return values. Each template parameter must be an instantiation
   of `ReturnValues`.

   A call to a function named in one of the `ReturnValues` returns that
   entry's values one at a time, in order. Once they are exhausted it
   returns the `.init` of their type, as the runtime-values `mockStruct`
   does. A call to any other function is only recorded, and returns `void`.
 */
auto mockStruct(T...)() if(T.length > 0 && from!"std.meta".allSatisfy!(isReturnValue, T)) {

    struct Mock {
        mixin MockImplCommon;

        // Index of the next value to return, per function name
        int[string] _retIndices;

        auto opDispatch(string funcName, V...)(auto ref V values) {

            import std.conv: to;
            import std.typecons: tuple;

            calledFuncs ~= funcName;
            calledValues ~= tuple(values).to!string;

            foreach(retVal; T) {
                static if(retVal.funcName == funcName) {
                    auto cannedValues = retVal.values;
                    // indices start at 0 and only grow, so the cast is safe
                    const index = cast(size_t) _retIndices[funcName]++;
                    if(index < cannedValues.length)
                        return cannedValues[index];
                    return typeof(cannedValues[0]).init;
                }
            }
        }
    }

    Mock mock;

    foreach(retVal; T) {
        mock._retIndices[retVal.funcName] = 0;
    }

    return mock;
}
475 
///
@("mock struct positive")
@safe pure unittest {
    void callFoobar(T)(T obj) {
        obj.foobar();
    }

    auto mocked = mockStruct;
    mocked.expect!"foobar";
    callFoobar(mocked);
    mocked.verify;
}
487 
488 
///
@("mock struct values positive")
@safe pure unittest {
    void callFoobar(T)(T obj) {
        obj.foobar(2, "quux");
    }

    auto mocked = mockStruct;
    mocked.expect!"foobar"(2, "quux");
    callFoobar(mocked);
    mocked.verify;
}
501 
502 
///
@("struct return value")
@safe pure unittest {

    int doubledTimes3(T)(T obj) {
        return 2 * obj.timesN(3);
    }

    auto mocked = mockStruct(42, 12);
    // queued values first, then int.init once they run out
    foreach(expected; [84, 24, 0])
        assert(doubledTimes3(mocked) == expected);
    mocked.expectCalled!"timesN";
}
517 
///
@("struct expectCalled")
@safe pure unittest {
    void callFoobar(T)(T obj) {
        obj.foobar(2, "quux");
    }

    auto mocked = mockStruct;
    callFoobar(mocked);
    mocked.expectCalled!"foobar"(2, "quux");
}
529 
///
@("mockStruct different return types for different functions")
@safe pure unittest {
    auto mocked = mockStruct!(ReturnValues!("length", 5),
                              ReturnValues!("greet", "hello"));
    immutable len = mocked.length;
    assert(len == 5);
    immutable greeting = mocked.greet("bar");
    assert(greeting == "hello");
    mocked.expectCalled!"length";
    mocked.expectCalled!"greet"("bar");
}
540 
///
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    auto mocked = mockStruct!(ReturnValues!("length", 5, 3),
                              ReturnValues!("greet", "hello", "g'day"));

    immutable firstLen = mocked.length;
    assert(firstLen == 5);
    mocked.expectCalled!"length";
    immutable secondLen = mocked.length;
    assert(secondLen == 3);
    mocked.expectCalled!"length";

    immutable firstGreeting = mocked.greet("bar");
    assert(firstGreeting == "hello");
    mocked.expectCalled!"greet"("bar");
    immutable secondGreeting = mocked.greet("quux");
    assert(secondGreeting == "g'day");
    mocked.expectCalled!"greet"("quux");
}
556 
557 
/**
   A mock struct that always throws.

   Calling any member function on it throws a new `E`. The exception is
   constructed with the message `"<function name> was called"` and the
   caller's file and line. The calls are declared to return `R`.
 */
auto throwStruct(E = from!"unit_threaded.should".UnitTestException, R = void)() {

    struct Mock {

        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)
                    (auto ref V values) {
            enum message = funcName ~ " was called";
            throw new E(message, file, line);
        }
    }

    Mock thrower;
    return thrower;
}
573 
///
@("throwStruct default")
@safe pure unittest {
    import std.exception: assertThrown;
    import unit_threaded.should: UnitTestException;

    auto thrower = throwStruct;
    assertThrown!UnitTestException(thrower.foo);
    assertThrown!UnitTestException(thrower.bar(1, "foo"));
}