1 /**
2    Support the automatic implementation of test doubles via programmable mocks.
3  */
4 module unit_threaded.mock;
5 
6 import unit_threaded.from;
7 
// Evaluates to its argument unchanged. Used in this module to give a name
// (via `alias`) to the result of a `__traits(...)` expression.
alias Identity(alias T) = T;
// True when `__traits(getMember, T, member)` does not compile from within this
// module; implMixinStr uses it to skip members it cannot reach.
private enum isPrivate(T, string member) = !__traits(compiles, __traits(getMember, T, member));
10 
11 
/**
   Builds, at compile time only, the D source code that overrides every
   reachable, mutable virtual method of `T` with a recording stub.

   For each overload the generated code:
   $(UL
     $(LI declares `<name>_<i>_parameters` and `<name>_<i>_returnType` aliases;)
     $(LI for non-void functions declares a `<name>_<i>_returnValues` queue
          that is consumed front to back, falling back to `.init`;)
     $(LI appends the function name to `calledFuncs` and the stringified
          argument tuple to `calledValues`.)
   )

   The returned string is meant to be mixed into a class deriving from `T`
   that has `Identity`, `Parameters`, `ReturnType`, `tuple`, `text`,
   `calledFuncs` and `calledValues` in scope.

   `const` and `immutable` methods are skipped because the stub has to
   mutate the mock's state.

   Returns: the source code to mix in, or `null` when called at run time.
 */
string implMixinStr(T)() {
    import std.array: join;
    import std.format : format;
    import std.traits: functionAttributes, FunctionAttribute, Parameters, ReturnType, arity;
    import std.conv: text;

    if(!__ctfe) return null;

    string[] lines;

    // Source text that names the i-th overload of `memberName` at the mixin site.
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`
            .format(memberName, i);
    }

    foreach(memberName; __traits(allMembers, T)) {

        static if(!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if(__traits(isVirtualMethod, member)) {
                foreach(i, overload; __traits(getOverloads, T, memberName)) {

                    // Neither const nor immutable methods can be mocked: the
                    // generated body appends to the mock's own arrays.
                    static if(!(functionAttributes!overload & FunctionAttribute.const_) &&
                              !(functionAttributes!overload & FunctionAttribute.immutable_)) {

                        enum overloadName = text(memberName, "_", i);
                        enum isNothrow = (functionAttributes!overload & FunctionAttribute.nothrow_) != 0;

                        enum overloadString = getOverload(memberName, i);
                        lines ~= "private alias %s_parameters = Parameters!(%s);".format(
                            overloadName, overloadString);
                        lines ~= "private alias %s_returnType = ReturnType!(%s);".format(
                            overloadName, overloadString);

                        // The recording statements of a nothrow function are
                        // wrapped in try/catch, hence one more indent level.
                        static if(isNothrow)
                            enum tryIndent = "    ";
                        else
                            enum tryIndent = "";

                        static if(is(ReturnType!overload == void))
                            enum returnDefault = "";
                        else {
                            enum varName = overloadName ~ `_returnValues`;
                            lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                            lines ~= "";
                            // Pop the next programmed value, or return .init when exhausted.
                            enum returnDefault = [`    if(` ~ varName ~ `.length > 0) {`,
                                                  `        auto ret = ` ~ varName ~ `[0];`,
                                                  `        ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                                  `        return ret;`,
                                                  `    } else`,
                                                  `        return %s_returnType.init;`.format(
                                                      overloadName)];
                        }

                        lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~
                            typeAndArgsParens!(Parameters!overload)(overloadName) ~ " " ~
                            functionAttributesString!overload ~ ` {`;

                        static if(isNothrow)
                            lines ~= "try {";

                        lines ~= tryIndent ~ `    calledFuncs ~= "` ~ memberName ~ `";`;
                        lines ~= tryIndent ~ `    calledValues ~= tuple` ~
                            argNamesParens(arity!overload) ~ `.text;`;

                        static if(isNothrow)
                            lines ~= "    } catch(Exception) {}";

                        lines ~= returnDefault;

                        lines ~= `}`;
                        lines ~= "";
                    }
                }
            }
        }
    }

    return lines.join("\n");
}
94 
// Compile-time helper: the comma-separated argument names produced by
// `argNames(N)`, wrapped in parentheses. Yields null when run at run time.
private string argNamesParens(int N) @safe pure {
    return __ctfe ? "(" ~ argNames(N) ~ ")" : null;
}
99 
// Compile-time helper: "arg0, arg1, ..." with N entries (empty for N <= 0).
// Yields null when run at run time.
private string argNames(int N) @safe pure {
    import std.conv: text;

    if(!__ctfe) return null;

    string names;
    foreach(index; 0 .. N) {
        if(index > 0)
            names ~= ", ";
        names ~= "arg" ~ index.text;
    }
    return names;
}
109 
// Compile-time helper: a parenthesised parameter list with one entry per type
// in `T`, of the form `<prefix>_parameters[i] arg<i>`. The types themselves
// are only counted; the entry refers to them through the
// `<prefix>_parameters` alias. Yields null when run at run time.
private string typeAndArgsParens(T...)(string prefix) {
    import std.array;
    import std.conv;

    if(!__ctfe) return null;

    string[] declarations;

    static foreach(index; 0 .. T.length)
        declarations ~= text(prefix, "_parameters[", index, "] arg", index);

    return "(" ~ declarations.join(", ") ~ ")";
}
123 
// Compile-time helper: spells out the attributes of function `F` as they
// would appear after a parameter list, space-separated, in this fixed order:
// pure, nothrow, @trusted, @safe, @nogc, @system, shared, @property.
// Yields null when run at run time.
private string functionAttributesString(alias F)() {
    import std.traits: functionAttributes, FunctionAttribute;
    import std.array: join;

    if(!__ctfe) return null;

    const attrs = functionAttributes!F;
    string[] spelled;

    // Appends `keyword` when `flag` is set on F.
    void addIf(FunctionAttribute flag, string keyword) {
        if(attrs & flag)
            spelled ~= keyword;
    }

    addIf(FunctionAttribute.pure_, "pure");
    addIf(FunctionAttribute.nothrow_, "nothrow");
    addIf(FunctionAttribute.trusted, "@trusted");
    addIf(FunctionAttribute.safe, "@safe");
    addIf(FunctionAttribute.nogc, "@nogc");
    addIf(FunctionAttribute.system, "@system");
    // const and immutable are deliberately left out: the mock has to
    // alter its own state from inside the overriding function.
    addIf(FunctionAttribute.shared_, "shared");
    addIf(FunctionAttribute.property, "@property");

    return spelled.join(" ");
}
149 
150 
// State and checks shared by every mock flavour in this module: the list of
// expected calls, the list of calls actually recorded, and the comparison
// between the two.
private mixin template MockImplCommon() {
    bool _verified;
    string[] expectedFuncs;
    string[] calledFuncs;
    string[] expectedValues;
    string[] calledValues;

    // Registers an expected call to `funcName`. With no `values` an empty
    // string is stored, which `verify` treats as "any arguments".
    void expect(string funcName, V...)(auto ref V values) {
        import std.conv: text;
        import std.typecons: tuple;

        expectedFuncs ~= funcName;
        static if(V.length == 0)
            expectedValues ~= "";
        else
            expectedValues ~= tuple(values).text;
    }

    // Registers the expectation and checks it straight away, then clears
    // the verified flag so that further checks remain possible.
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    // Compares the expectations, in order, with the recorded calls and
    // reports the first mismatch. Calling it again while the verified flag
    // is still set is reported as a failure.
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range: repeat, take, join;
        import std.conv: text;
        import unit_threaded.exception: fail, UnitTestException;

        if(_verified)
            fail("Mock already _verified", file, line);

        _verified = true;

        foreach(i, expectedFunc; expectedFuncs) {

            if(i >= calledFuncs.length)
                fail("Expected nth " ~ i.text ~ " call to `" ~ expectedFunc ~ "` did not happen", file, line);

            if(expectedFunc != calledFuncs[i])
                fail("Expected nth " ~ i.text ~ " call to `" ~ expectedFunc ~ "` but got `" ~ calledFuncs[i] ~
                     "` instead",
                     file, line);

            string wanted = expectedValues[i];
            string got = calledValues[i];

            // An empty expectation means the arguments were not specified.
            if(wanted != got && wanted != "") {
                // Pads the second line so it lines up under the first.
                string padding = " ".repeat.take(expectedFunc.length + 4).join;
                throw new UnitTestException([expectedFunc ~ " was called with unexpected " ~ got,
                                             padding ~ "instead of the expected " ~ wanted],
                                            file, line);
            }
        }
    }
}
203 
// True when the aliased symbol has a type and that type is exactly `string`.
// NOTE(review): no use of this template is visible in this part of the file.
private enum isString(alias T) = is(typeof(T) == string);
205 
206 /**
207    A mock object that conforms to an interface/class.
208  */
struct Mock(T) {

    // The generated implementation of T. `alias this` lets a Mock be used
    // wherever a T is expected and exposes expect/verify directly.
    MockAbstract _impl;
    alias _impl this;

    class MockAbstract: T {
        // needed by implMixinStr
        import std.conv: text;
        import std.traits: Parameters, ReturnType;
        import std.typecons: tuple;

        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    ///
    this(int/* force constructor*/) {
        _impl = new MockAbstract;
    }

    /// Verifies the expectations at scope exit unless already verified.
    ~this() pure @safe {
        // A default-initialised Mock (e.g. `Mock!T m;`) never got an
        // implementation object; there is nothing to verify and touching
        // `_verified` through the null reference would crash.
        if(_impl is null) return;
        if(!_verified) verify;
    }

    /// Set the returnValue of a function to certain values.
    void returnValue(string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        // Without an explicit index the first overload is programmed.
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index. e.g.:
       ---------
       interface Interface { void foo(int); void foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo"); // int overload
       m.returnValue!(1, "foo"); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        import std.conv: text;
        // Name of the queue that implMixinStr generates for this overload.
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach(v; values)
            mixin(varName ~ ` ~=  v;`);
    }

    // Compile-time check that `funcName` names a virtual method of T.
    private static void assertFunctionIsVirtual(string funcName)() {
        alias member = Identity!(__traits(getMember, T, funcName));

        static assert(__traits(isVirtualMethod, member),
                      "Cannot use returnValue on '" ~ funcName ~ "'");
    }
}
266 
267 
// Compile-time helper: one `import <module>;` line (newline-terminated) for
// `module_` followed by one for each entry of `Modules`, in that order.
// Yields null when run at run time.
private string importsString(string module_, string[] Modules...) {
    if(!__ctfe) return null;

    string code;
    foreach(name; [module_] ~ Modules)
        code ~= `import ` ~ name ~ ";\n";
    return code;
}
277 
278 /// Helper function for creating a Mock object.
auto mock(T)() {
    // The int argument carries no meaning: it only selects Mock's single
    // constructor, which is the one that allocates the implementation object.
    return Mock!T(0);
}
282 
283 ///
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    // Expect a call to `foo` without constraining its arguments.
    m.expect!"foo";
    fun(m);
    // No explicit verify: the mock's destructor checks the expectation.
}
299 
300 
301 ///
@("mock interface positive test with params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    // Expect a call to `foo` with exactly these arguments.
    m.expect!"foo"(5, "foobar");
    fun(m);
    // No explicit verify: the mock's destructor checks the expectation.
}
317 
318 
319 ///
@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    // Exercise the code first, then state what should have happened.
    fun(m);
    m.expectCalled!"foo"(5, "foobar");
}
335 
336 ///
@("interface return value")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    // Program the value that the mocked `timesN` hands back.
    m.returnValue!"timesN"(42);
    immutable res = fun(m);
    assert(res == 84);
}
353 
354 ///
@("interface return values")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    // Several values are returned one per call, in the order given.
    m.returnValue!"timesN"(42, 12);
    assert(fun(m) == 84);
    assert(fun(m) == 24);
    // Once the programmed values run out the mock returns int.init.
    assert(fun(m) == 0);
}
372 
373 
// Compile-time description of the values a named function of a mock struct
// should return: `function_` is the function name and `T` the (non-empty)
// sequence of values.
struct ReturnValues(string function_, T...) if(from!"std.meta".allSatisfy!(isValue, T) && T.length > 0) {

    alias funcName = function_;
    alias Values = T;

    // The values as an array whose element type is that of the first value.
    static values() {
        alias Element = typeof(T[0]);

        Element[] collected;
        foreach(item; T)
            collected ~= item;
        return collected;
    }
}
387 
// True when T matches some instantiation of the ReturnValues template.
enum isReturnValue(alias T) = is(T: ReturnValues!U, U...);
// True when `typeof(T)` is valid, i.e. the aliased symbol has a type.
enum isValue(alias T) = is(typeof(T));
390 
391 
392 /**
393    Version of mockStruct that accepts 0 or more values of the same
394    type. Whatever function is called on it, these values will
395    be returned one by one. The limitation is that if more than one
396    function is called on the mock, they all return the same type
397  */
auto mockStruct(T...)(auto ref T returns) if(!from!"std.meta".anySatisfy!(isMockReturn, T)) {

    static struct Mock {

        // The state lives behind a pointer, so copies of the mock (e.g. when
        // passed by value to the code under test) share the recorded calls.
        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            // The queue of canned values only exists if any were supplied.
            static if(T.length > 0) {
                alias FirstType = typeof(returns[0]);
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            // Any member call lands here: it is recorded and answered with
            // the next canned value (if there are any).
            auto opDispatch(string funcName, this This, V...)
                           (auto ref V values)
            {

                import std.typecons: tuple;
                import std.conv: text;

                // State can only be altered through a mutable receiver.
                enum isMutable = !(is(This == const) || is(This == immutable));

                static if(isMutable) {
                    calledFuncs ~= funcName;
                    calledValues ~= tuple(values).text;
                }

                static if(T.length > 0) {

                    // Out of canned values: fall back to the default value.
                    if(_returnValues.length == 0)
                        return typeof(_returnValues[0]).init;

                    auto next = _returnValues[0];
                    static if(isMutable)
                        _returnValues = _returnValues[1..$];
                    return next;
                }
            }
        }
    }

    Mock result;

    // Allocated here rather than in the field's declaration: an initialiser
    // there would be evaluated at compile time and a single instance would
    // be shared by every mock. Structs cannot have default constructors, so
    // this is the only place left.
    result._impl = new typeof(result).MockImpl;
    static if(T.length > 0) {
        foreach(value; returns)
            result._impl._returnValues ~= value;
    }

    return result;
}
454 
455 
/**
   Version of mockStruct that accepts a compile-time mapping
   of function name to return values. Each template parameter
   must be an instantiation of the `ReturnValues` template,
   e.g. `ReturnValues!("length", 5, 3)`.
 */
auto mockStruct(T...)() if(T.length > 0 && from!"std.meta".allSatisfy!(isReturnValue, T)) {

    static struct Mock {
        mixin MockImplCommon;

        // For each configured function name, the index of the value that
        // the next call will return.
        int[string] _retIndices;

        // Any member call lands here: it is recorded and, when the name has
        // configured values, answered with the next one.
        auto opDispatch(string funcName, this This, V...)
                       (auto ref V values)
        {

            import std.typecons: tuple;
            import std.conv: text;

            // State can only be altered through a mutable receiver.
            enum isMutable = !(is(This == const) || is(This == immutable));

            static if(isMutable) {
                calledFuncs ~= funcName;
                calledValues ~= tuple(values).text;
            }

            foreach(spec; T) {
                static if(spec.funcName == funcName) {
                    auto next = spec.values[_retIndices[funcName]];
                    static if(isMutable)
                        ++_retIndices[funcName];
                    return next;
                }
            }
        }

        // NOTE(review): looks like a debugging leftover - it always reads the
        // first ReturnValues entry using the hard-coded key "greet". Kept
        // because it is a member of the returned type; confirm and remove.
        auto lefoofoo() {
            return T[0].values[_retIndices["greet"]++];
        }

    }

    Mock result;

    // Every configured function starts at its first value.
    foreach(spec; T) {
        result._retIndices[spec.funcName] = 0;
    }

    return result;
}
506 
507 ///
@("mock struct positive")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar;
    }
    // With no arguments, mockStruct only records calls.
    auto m = mockStruct;
    m.expect!"foobar";
    fun(m);
    m.verify;
}
518 
519 
520 ///
@("mock struct values positive")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    // The expected arguments are compared with the recorded ones by verify.
    m.expect!"foobar"(2, "quux");
    fun(m);
    m.verify;
}
532 
533 
534 ///
@("struct return value")
@safe pure unittest {

    int fun(T)(T f) {
        return f.timesN(3) * 2;
    }

    // The values are handed out one per call, whatever function is called.
    auto m = mockStruct(42, 12);
    assert(fun(m) == 84);
    assert(fun(m) == 24);
    // Once the values run out the mock returns int.init.
    assert(fun(m) == 0);
    m.expectCalled!"timesN";
}
548 
549 ///
@("struct expectCalled")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    // Exercise the code first, then state what should have happened.
    fun(m);
    m.expectCalled!"foobar"(2, "quux");
}
560 
561 ///
@("mockStruct different return types for different functions")
@safe pure unittest {
    // Each ReturnValues entry ties a function name to its own return value,
    // so `length` can return an int while `greet` returns a string.
    auto m = mockStruct!(ReturnValues!("length", 5),
                         ReturnValues!("greet", "hello"));
    assert(m.length == 5);
    assert(m.greet("bar") == "hello");
    m.expectCalled!"length";
    m.expectCalled!"greet"("bar");
}
571 
572 ///
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    // Each function has its own sequence of values, consumed independently.
    auto m = mockStruct!(ReturnValues!("length", 5, 3),
                         ReturnValues!("greet", "hello", "g'day"));
    assert(m.length == 5);
    m.expectCalled!"length";
    assert(m.length == 3);
    m.expectCalled!"length";

    assert(m.greet("bar") == "hello");
    m.expectCalled!"greet"("bar");
    assert(m.greet("quux") == "g'day");
    m.expectCalled!"greet"("quux");
}
587 
588 
589 /**
590    A mock struct that always throws.
591  */
auto throwStruct(E = from!"unit_threaded.exception".UnitTestException, R = void)() {

    struct Mock {

        // Any member call lands here and throws an `E` whose message names
        // the function; the file and line are those of the call site.
        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)
                    (auto ref V values) {
            enum message = funcName ~ " was called";
            throw new E(message, file, line);
        }
    }

    return Mock();
}
604 
605 ///
@("throwStruct default")
@safe pure unittest {
    import std.exception: assertThrown;
    import unit_threaded.exception: UnitTestException;
    auto m = throwStruct;
    // Every call throws, with or without arguments.
    assertThrown!UnitTestException(m.foo);
    assertThrown!UnitTestException(m.bar(1, "foo"));
}
614 
615 
/**
   Version of mockStruct that takes one or more results of `mockReturn`.
   Each of them names a function and carries the values that function
   returns, one per call.
 */
auto mockStruct(R...)(auto ref R returns) if(R.length > 0 && from!"std.meta".allSatisfy!(isMockReturn, R)) {

    struct Mock {

        mixin MockImplCommon;

        // For each configured function name, the index of the value that
        // the next call will return.
        int[string] _retIndices;

        // Any member call lands here: it is recorded and answered with the
        // next value configured for that name.
        auto opDispatch(string funcName, this This, V...)
                       (auto ref V values)
        {

            import std.typecons: tuple;
            import std.conv: text;

            // State can only be altered through a mutable receiver.
            enum isMutable = !(is(This == const) || is(This == immutable));

            static if(isMutable) {
                calledFuncs ~= funcName;
                calledValues ~= tuple(values).text;
            }

            static foreach(i, returnType; R) {
                static if(returnType.Name == funcName) {
                    auto ret = returns[i].values[_retIndices[funcName]];

                    static if(isMutable)
                        ++_retIndices[funcName];

                    return ret;
                }
            }

            // Only reached when no entry matched the called name.
            assert(0, "No return value for `" ~ funcName ~ "`");
        }
    }

    Mock result;

    // Every configured function starts at its first value.
    static foreach(returnType; R) {
        result._retIndices[returnType.Name] = 0;
    }

    return result;

}
662 
/**
   Bundles the name of a function with the values it should return, for use
   with `mockStruct`. The element type is that of the first value.
 */
auto mockReturn(string name, V...)(auto ref V values) {
    alias Element = V[0];
    return MockReturn!(name, Element)(values);
}
666 
/**
   Evaluates to `true` when every type in `V` is the same type as the
   first one, e.g. `allSameType!(int, int)` is `true` and
   `allSameType!(int, string)` is `false`.
 */
template allSameType(V...) {
    import std.meta: allSatisfy;
    // Compare with the first element. Comparing with `V` itself would test
    // each type against the whole sequence, which never holds for two or
    // more types.
    enum isSameAsFirst(T) = is(T == V[0]);
    enum allSameType = allSatisfy!(isSameAsFirst, V);
}
672 
673 
// The values (`values`) that a function called `funcName` (exposed as
// `Name`) should return from a mock struct.
private struct MockReturn(string funcName, V) {

    alias Name = funcName;
    V[] values;

    // Stores the arguments, in order, as the values to return.
    this(A...)(auto ref A args) {
        foreach(item; args) {
            values ~= item;
        }
    }
}
683 
// True when T is an instantiation of MockReturn, for any name and value type.
enum isMockReturn(T) = is(T == MockReturn!(name, V), string name, V);
// Sanity check: what mockReturn produces is recognised by isMockReturn.
static assert(isMockReturn!(typeof(mockReturn!"length"(42))));