1 /**
2    Support the automatic implementation of test doubles via programmable mocks.
3  */
4 module unit_threaded.mock;
5 
6 import unit_threaded.from;
7 
/// Yields exactly the symbol or type it is given; lets `__traits` results be aliased.
template Identity(alias T) {
    alias Identity = T;
}

/// True when `__traits(getMember, T, member)` does not compile from this
/// module — presumably because the member is private/inaccessible.
private template isPrivate(T, string member) {
    enum isPrivate = !__traits(compiles, __traits(getMember, T, member));
}
10 
11 
/**
   Generates, as a string to be mixed into a class deriving from `T`, the
   overrides that turn it into a recording mock.

   For every overload `i` of every accessible virtual method `memberName` of
   `T` that is neither `const` nor `immutable`, the generated code contains:
   - `memberName_i_parameters` and `memberName_i_returnType` aliases;
   - for non-void overloads, a `memberName_i_returnValues` array of queued
     return values;
   - an `override` that appends the call to `calledFuncs` / `calledValues`
     and then pops the next queued value, or returns `.init` when none are
     left. For `nothrow` overloads the recording is wrapped in a try/catch
     that discards any `Exception`.

   The generated code expects `Identity`, `Parameters`, `ReturnType`, `tuple`,
   `to`, `calledFuncs` and `calledValues` to be in scope at the mixin site.

   Only meaningful at compile time: returns `null` when called at run time.
 */
string implMixinStr(T)() {
    import std.array: join;
    import std.format : format;
    import std.traits: functionAttributes, FunctionAttribute, Parameters, ReturnType, arity;
    import std.conv: text;

    if(!__ctfe) return null;

    string[] lines;

    // Source text that names overload `i` of `memberName` on `T`.
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`
            .format(memberName, i);
    }

    foreach(memberName; __traits(allMembers, T)) {

        static if(!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if(__traits(isVirtualMethod, member)) {
                foreach(i, overload; __traits(getOverloads, T, memberName)) {

                    enum attrs = functionAttributes!overload;

                    // const and immutable overloads are skipped: the mock has
                    // to mutate its own state when recording calls.
                    static if(!(attrs & FunctionAttribute.const_) &&
                              !(attrs & FunctionAttribute.immutable_)) {

                        enum isNothrow = (attrs & FunctionAttribute.nothrow_) != 0;
                        enum overloadName = text(memberName, "_", i);
                        enum overloadString = getOverload(memberName, i);

                        lines ~= "private alias %s_parameters = Parameters!(%s);".format(
                            overloadName, overloadString);
                        lines ~= "private alias %s_returnType = ReturnType!(%s);".format(
                            overloadName, overloadString);

                        // nothrow overrides wrap the recording lines in try/catch,
                        // so those lines get one extra level of indentation.
                        enum tryIndent = isNothrow ? "    " : "";

                        static if(is(ReturnType!overload == void))
                            enum returnDefault = "";
                        else {
                            enum varName = overloadName ~ `_returnValues`;
                            lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                            lines ~= "";
                            // Pop the next queued value, falling back to .init.
                            enum returnDefault = [`    if(` ~ varName ~ `.length > 0) {`,
                                                  `        auto ret = ` ~ varName ~ `[0];`,
                                                  `        ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                                  `        return ret;`,
                                                  `    } else`,
                                                  `        return %s_returnType.init;`.format(
                                                      overloadName)];
                        }

                        lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~
                            typeAndArgsParens!(Parameters!overload)(overloadName) ~ " " ~
                            functionAttributesString!overload ~ ` {`;

                        static if(isNothrow)
                            lines ~= "    try {";

                        lines ~= tryIndent ~ `    calledFuncs ~= "` ~ memberName ~ `";`;
                        lines ~= tryIndent ~ `    calledValues ~= tuple` ~
                            argNamesParens(arity!overload) ~ `.to!string;`;

                        static if(isNothrow)
                            lines ~= "    } catch(Exception) {}";

                        lines ~= returnDefault;

                        lines ~= `}`;
                        lines ~= "";
                    }
                }
            }
        }
    }

    return lines.join("\n");
}
94 
/// The output of `argNames(N)` wrapped in parentheses, e.g. `(arg0, arg1)`.
/// Compile-time only: yields `null` when evaluated at run time.
private string argNamesParens(int N) @safe pure {
    if(__ctfe)
        return "(" ~ argNames(N) ~ ")";
    return null;
}
99 
/// Comma-separated list `arg0, arg1, ..., arg{N-1}`; empty when N <= 0.
/// Compile-time only: yields `null` when evaluated at run time.
private string argNames(int N) @safe pure {
    import std.conv: to;

    if(!__ctfe) return null;

    string result;
    foreach(idx; 0 .. N) {
        if(idx > 0) result ~= ", ";
        result ~= "arg" ~ idx.to!string;
    }
    return result;
}
108 
/// Builds the parenthesised parameter list of a generated override, e.g.
/// `(pfx_parameters[0] arg0, pfx_parameters[1] arg1)` when `T` has two
/// entries. Only the number of entries in `T` is used, not the types.
/// Compile-time only: yields `null` when evaluated at run time.
private string typeAndArgsParens(T...)(string prefix) {
    import std.array;
    import std.format : format;

    if(!__ctfe) return null;

    string[] params;
    foreach(idx; 0 .. T.length)
        params ~= format("%s_parameters[%s] arg%s", prefix, idx, idx);

    return "(" ~ join(params, ", ") ~ ")";
}
122 
/**
   Spells out the attributes of `F` that a generated override repeats, as a
   space-separated string such as "pure nothrow @safe".

   `const` and `immutable` are intentionally never emitted: the mock has to
   mutate its own state when recording calls.
   Compile-time only: yields `null` when evaluated at run time.
 */
private string functionAttributesString(alias F)() {
    import std.traits: functionAttributes, FunctionAttribute;
    import std.array: join;

    if(!__ctfe) return null;

    enum attrs = functionAttributes!F;
    string[] spelled;

    // Appends `keyword` when `F` carries `flag`.
    void emitIf(FunctionAttribute flag, string keyword) {
        if(attrs & flag) spelled ~= keyword;
    }

    emitIf(FunctionAttribute.pure_,    "pure");
    emitIf(FunctionAttribute.nothrow_, "nothrow");
    emitIf(FunctionAttribute.trusted,  "@trusted");
    emitIf(FunctionAttribute.safe,     "@safe");
    emitIf(FunctionAttribute.nogc,     "@nogc");
    emitIf(FunctionAttribute.system,   "@system");
    emitIf(FunctionAttribute.shared_,  "shared");
    emitIf(FunctionAttribute.property, "@property");

    return join(spelled, " ");
}
148 
/**
   Call-recording state and checks shared by the mocks in this module.

   The type mixing this in must append to `calledFuncs` and `calledValues`
   whenever one of its functions is called. `expect` fills `expectedFuncs`
   and `expectedValues`, and `verify` compares the two position by position.
 */
mixin template MockImplCommon() {
    bool _verified;
    string[] expectedFuncs;
    string[] calledFuncs;
    string[] expectedValues;
    string[] calledValues;

    /// Records `funcName` as the next expected call. When no `values` are
    /// given, any arguments are accepted for that call.
    void expect(string funcName, V...)(auto ref V values) {
        import std.conv: to;
        import std.typecons: tuple;

        expectedFuncs ~= funcName;

        // An empty string is the "any arguments" marker that `verify` skips.
        static if(V.length == 0)
            expectedValues ~= "";
        else
            expectedValues ~= tuple(values).to!string;
    }

    /// Adds an expectation and verifies immediately, then clears the
    /// verified flag so that later calls to `verify` check again.
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    /**
       Checks that the recorded calls match the expected ones, in order.
       Recorded calls beyond the number of expectations are not checked.
       Fails (via `fail`, or by throwing `UnitTestException`) on a missing
       call, a call to the wrong function, mismatching arguments, or when this
       mock was already verified.
     */
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range: repeat, take, join;
        import std.conv: to;
        import unit_threaded.should: fail, UnitTestException;

        if(_verified)
            fail("Mock already _verified", file, line);

        _verified = true;

        foreach(i, expectedFunc; expectedFuncs) {
            immutable nth = i.to!string;

            if(i >= calledFuncs.length)
                fail("Expected nth " ~ nth ~ " call to " ~ expectedFunc ~ " did not happen", file, line);

            immutable calledFunc = calledFuncs[i];
            if(calledFunc != expectedFunc)
                fail("Expected nth " ~ nth ~ " call to " ~ expectedFunc ~ " but got " ~ calledFunc ~
                     " instead",
                     file, line);

            immutable expectedArgs = expectedValues[i];
            immutable calledArgs = calledValues[i];
            immutable argsChecked = expectedArgs != "";
            if(argsChecked && calledArgs != expectedArgs)
                throw new UnitTestException([expectedFunc ~ " was called with unexpected " ~ calledArgs,
                                             " ".repeat.take(expectedFunc.length + 4).join ~
                                             "instead of the expected " ~ expectedArgs],
                                            file, line);
        }
    }
}
201 
/// True when the symbol `T` is an expression whose type is exactly `string`.
private template isString(alias T) {
    enum isString = is(typeof(T) == string);
}
203 
/**
   A mock object that conforms to an interface/class.

   `Mock!T` owns an instance of a class derived from `T` whose overrides are
   generated by `implMixinStr`. It converts to that instance via `alias this`,
   so it can be passed wherever a `T` is expected. When the `Mock` is
   destroyed, it verifies the recorded calls unless it is already marked as
   verified.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    class MockAbstract: T {
        import std.conv: to;
        import std.traits: Parameters, ReturnType;
        import std.typecons: tuple;

        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    /// Creates the underlying implementation. The int argument only exists
    /// because a struct cannot declare a default constructor.
    this(int/* force constructor*/) {
        _impl = new MockAbstract;
    }

    /// Verifies expectations unless already verified. A default-initialised
    /// `Mock` (null `_impl`) has nothing to verify; without this guard its
    /// destructor would dereference null.
    ~this() pure @safe {
        if(_impl !is null && !_verified) verify;
    }

    /// Queues `values` to be returned, in order, by successive calls to the
    /// first overload of `funcName`; once exhausted, `.init` is returned.
    void returnValue(string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index. e.g.:
       ---------
       interface Interface { int foo(int); int foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo")(1); // int overload
       m.returnValue!(1, "foo")(2); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        import std.conv: text;
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach(v; values)
            mixin(varName ~ ` ~=  v;`);
    }

    // Rejects, at compile time, return values for non-virtual members.
    private static void assertFunctionIsVirtual(string funcName)() {
        alias member = Identity!(__traits(getMember, T, funcName));

        static assert(__traits(isVirtualMethod, member),
                      "Cannot use returnValue on '" ~ funcName ~ "'");
    }
}
264 
/// Renders one `import X;` line per module name, in argument order.
/// Compile-time only: yields `null` when evaluated at run time.
private string importsString(string module_, string[] Modules...) {
    if(!__ctfe) return null;

    string rendered;
    foreach(name; module_ ~ Modules)
        rendered ~= "import " ~ name ~ ";\n";
    return rendered;
}
274 
/// Helper function for creating a Mock object, e.g. `auto m = mock!Interface;`.
Mock!T mock(T)() {
    return Mock!T(int.init);
}
279 
///
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    // Code under test: calls foo once and doubles the result.
    int doubleFoo(Foo f) {
        return f.foo(5, "foobar") * 2;
    }

    auto m = mock!Foo;
    // No arguments given: any arguments passed to foo are accepted.
    m.expect!"foo";
    doubleFoo(m);
}
296 
297 
///
@("mock interface positive test with params")
@safe pure unittest {
    import unit_threaded.asserts;

    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int doubleFoo(Foo f) {
        return f.foo(5, "foobar") * 2;
    }

    auto m = mock!Foo;
    // The recorded arguments must match these when m is verified
    // (here by its destructor at the end of the scope).
    m.expect!"foo"(5, "foobar");
    doubleFoo(m);
}
316 
317 
///
@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int doubleFoo(Foo f) {
        return f.foo(5, "foobar") * 2;
    }

    auto m = mock!Foo;
    doubleFoo(m);
    // Checks after the fact, without a prior `expect`.
    m.expectCalled!"foo"(5, "foobar");
}
334 
///
@("interface return value")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubledTimes3(Foo f) {
        return 2 * f.timesN(3);
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42);
    assert(doubledTimes3(m) == 84);
}
352 
///
@("interface return values")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubledTimes3(Foo f) {
        return 2 * f.timesN(3);
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42, 12);
    // Queued values come back in order, then int.init once exhausted.
    foreach(expected; [84, 24, 0])
        assert(doubledTimes3(m) == expected);
}
371 
/**
   Compile-time description of canned return values for one function of a
   `mockStruct`: `function_` names the function and `T` holds the values,
   handed out in order. The values are collected into an array of the
   type of the first one.
 */
struct ReturnValues(string function_, T...) if(from!"std.meta".allSatisfy!(isValue, T)) {
    alias funcName = function_;
    alias Values = T;

    /// The values in `T` as an array of `typeof(T[0])`.
    static auto values() {
        alias Element = typeof(T[0]);
        Element[] collected;
        foreach(value; T)
            collected ~= value;
        return collected;
    }
}
384 
/// True when `T` matches some instantiation `ReturnValues!(...)`.
template isReturnValue(alias T) {
    enum isReturnValue = is(T: ReturnValues!U, U...);
}

/// True when `T` has a type, i.e. is a value rather than e.g. a type.
template isValue(alias T) {
    enum isValue = is(typeof(T));
}
387 
388 
/**
   Creates a duck-typed mock struct: any member function called on it is
   accepted via `opDispatch` and recorded for `expect` / `verify`.

   `returns` (0 or more values of the same type) are handed back one by one,
   in order, whatever function is called; once they run out, `.init` of the
   first value's type is returned. With no `returns`, calls return `void`.
   The limitation is that all functions called on the mock share the same
   return type.
 */
auto mockStruct(T...)(auto ref T returns) {

    struct Mock {

        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            static if(T.length > 0) {
                alias FirstType = typeof(returns[0]);
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            auto opDispatch(string funcName, V...)(auto ref V values) {

                import std.conv: to;
                import std.typecons: tuple;

                calledFuncs ~= funcName;
                calledValues ~= tuple(values).to!string;

                static if(T.length > 0) {
                    if(_returnValues.length > 0) {
                        auto next = _returnValues[0];
                        _returnValues = _returnValues[1 .. $];
                        return next;
                    }
                    return FirstType.init;
                }
            }
        }
    }

    Mock m;
    m._impl = new Mock.MockImpl;

    static if(T.length > 0) {
        foreach(value; returns)
            m._impl._returnValues ~= value;
    }

    return m;
}
439 
/**
   Version of mockStruct that accepts a compile-time mapping
   of function name to return values. Each template parameter
   must be an instantiation of `ReturnValues`.

   Calling a function named by one of the `ReturnValues` returns its values
   one by one, in order; once they run out, `.init` of the values' type is
   returned (consistent with the other `mockStruct` overload). Calls to
   functions not named by any `ReturnValues` are still recorded but return
   `void`.
 */
auto mockStruct(T...)() if(T.length > 0 && from!"std.meta".allSatisfy!(isReturnValue, T)) {

    struct Mock {
        mixin MockImplCommon;

        // Per-function index of the next canned value to hand out.
        int[string] _retIndices;

        auto opDispatch(string funcName, V...)(auto ref V values) {

            import std.conv: to;
            import std.typecons: tuple;

            calledFuncs ~= funcName;
            calledValues ~= tuple(values).to!string;

            foreach(retVal; T) {
                static if(retVal.funcName == funcName) {
                    auto canned = retVal.values;
                    immutable index = _retIndices[funcName]++;
                    // Past the end: return .init like the other mocks do
                    // instead of failing with a RangeError.
                    return index < canned.length ? canned[index] : typeof(canned[0]).init;
                }
            }
        }
    }

    Mock mock;

    foreach(retVal; T) {
        mock._retIndices[retVal.funcName] = 0;
    }

    return mock;
}
481 
///
@("mock struct positive")
@safe pure unittest {
    void callFoobar(T)(T mocked) {
        mocked.foobar;
    }

    auto m = mockStruct;
    m.expect!"foobar";
    callFoobar(m);
    m.verify;
}
493 
494 
///
@("mock struct values positive")
@safe pure unittest {
    void callFoobar(T)(T mocked) {
        mocked.foobar(2, "quux");
    }

    auto m = mockStruct;
    m.expect!"foobar"(2, "quux");
    callFoobar(m);
    m.verify;
}
507 
508 
///
@("struct return value")
@safe pure unittest {

    int doubledTimes3(T)(T mocked) {
        return 2 * mocked.timesN(3);
    }

    auto m = mockStruct(42, 12);
    // Values come back in order, then int.init once exhausted.
    foreach(expected; [84, 24, 0])
        assert(doubledTimes3(m) == expected);
    m.expectCalled!"timesN";
}
523 
///
@("struct expectCalled")
@safe pure unittest {
    void callFoobar(T)(T mocked) {
        mocked.foobar(2, "quux");
    }

    auto m = mockStruct;
    callFoobar(m);
    m.expectCalled!"foobar"(2, "quux");
}
535 
///
@("mockStruct different return types for different functions")
@safe pure unittest {
    auto m = mockStruct!(ReturnValues!("length", 5),
                         ReturnValues!("greet", "hello"));

    immutable len = m.length;
    assert(len == 5);
    immutable greeting = m.greet("bar");
    assert(greeting == "hello");

    m.expectCalled!"length";
    m.expectCalled!"greet"("bar");
}
546 
///
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    auto m = mockStruct!(ReturnValues!("length", 5, 3),
                         ReturnValues!("greet", "hello", "g'day"));

    // Each function hands out its own queue of values, in order.
    foreach(expectedLength; [5, 3]) {
        assert(m.length == expectedLength);
        m.expectCalled!"length";
    }

    assert(m.greet("bar") == "hello");
    m.expectCalled!"greet"("bar");
    assert(m.greet("quux") == "g'day");
    m.expectCalled!"greet"("quux");
}
562 
563 
/**
   A mock struct that always throws.

   Every member-function call, whatever its name or arguments, throws a new
   `E` with the message "<name> was called", attributed to the caller's file
   and line. `R` is the declared return type of those calls.
 */
auto throwStruct(E = from!"unit_threaded.should".UnitTestException, R = void)() {

    struct Mock {

        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)
                    (auto ref V values) {
            enum message = funcName ~ " was called";
            throw new E(message, file, line);
        }
    }

    Mock thrower;
    return thrower;
}
579 
///
@("throwStruct default")
@safe pure unittest {
    import std.exception: assertThrown;
    import unit_threaded.should: UnitTestException;

    auto thrower = throwStruct;
    // Any call, with or without arguments, throws the default exception type.
    assertThrown!UnitTestException(thrower.foo);
    assertThrown!UnitTestException(thrower.bar(1, "foo"));
}