1 /**
2    Support the automatic implementation of test doubles via programmable mocks.
3  */
4 module unit_threaded.mock;
5 
6 import unit_threaded.from;
7 
/// Yields `T` unchanged; lets a trait result be bound with `alias`.
template Identity(alias T) {
    alias Identity = T;
}

/// True when `T.member` cannot be accessed (so it must not be mocked).
private template isPrivate(T, string member) {
    enum isPrivate = !__traits(compiles, __traits(getMember, T, member));
}
10 
11 
/**
   Generates D source that, when mixed in to a class deriving from `T`,
   overrides the virtual methods of `T` with call-recording mocks.

   For overload number `i` of a member `name` the generated code contains:
   - `name_i_parameters` / `name_i_returnType` aliases for the overload's
     parameter tuple and return type;
   - for non-void overloads, a `name_i_returnValues` array holding the
     programmed return values;
   - an `override` carrying the attributes produced by
     `functionAttributesString`. It appends `name` to `calledFuncs` and the
     stringified arguments to `calledValues`. It then returns the next
     programmed value, or the return type's `.init` once none are left.
     For `nothrow` overloads the recording is wrapped in
     `try`/`catch(Exception)` so the override stays `nothrow`.

   The generated code refers to `Parameters`, `ReturnType`, `tuple`, `to`,
   `calledFuncs` and `calledValues`; these must be in scope at the mixin
   site. Inaccessible members, non-virtual members and const/immutable
   overloads are skipped.

   Returns: the generated code, or `null` when not evaluated via CTFE.
 */
string implMixinStr(T)() {
    import std.array: join;
    import std.format : format;
    import std.traits: functionAttributes, FunctionAttribute, Parameters, ReturnType, arity;
    import std.conv: text;

    if(!__ctfe) return null;

    string[] lines;

    // Source text that names overload `i` of `memberName` in `T`.
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`
            .format(memberName, i);
    }

    foreach(memberName; __traits(allMembers, T)) {

        static if(!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if(__traits(isVirtualMethod, member)) {
                foreach(i, overload; __traits(getOverloads, T, memberName)) {

                    // const and immutable overloads are skipped: the generated
                    // override has to mutate the mock's state.
                    static if(!(functionAttributes!overload & FunctionAttribute.const_) &&
                              !(functionAttributes!overload & FunctionAttribute.immutable_)) {

                        enum overloadName = text(memberName, "_", i);

                        enum overloadString = getOverload(memberName, i);
                        lines ~= "private alias %s_parameters = Parameters!(%s);".format(
                            overloadName, overloadString);
                        lines ~= "private alias %s_returnType = ReturnType!(%s);".format(
                            overloadName, overloadString);

                        // In a nothrow override the recording lines live inside a
                        // try block, so they get one extra level of indentation.
                        static if(functionAttributes!overload & FunctionAttribute.nothrow_)
                            enum tryIndent = "    ";
                        else
                            enum tryIndent = "";

                        static if(is(ReturnType!overload == void))
                            enum returnDefault = "";
                        else {
                            // Queue of programmed return values, consumed front first.
                            enum varName = overloadName ~ `_returnValues`;
                            lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                            lines ~= "";
                            enum returnDefault = [`    if(` ~ varName ~ `.length > 0) {`,
                                                  `        auto ret = ` ~ varName ~ `[0];`,
                                                  `        ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                                  `        return ret;`,
                                                  `    } else`,
                                                  `        return %s_returnType.init;`.format(
                                                      overloadName)];
                        }

                        lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~
                            typeAndArgsParens!(Parameters!overload)(overloadName) ~ " " ~
                            functionAttributesString!overload ~ ` {`;

                        static if(functionAttributes!overload & FunctionAttribute.nothrow_)
                            lines ~= "    try {";

                        lines ~= tryIndent ~ `    calledFuncs ~= "` ~ memberName ~ `";`;
                        lines ~= tryIndent ~ `    calledValues ~= tuple` ~
                            argNamesParens(arity!overload) ~ `.to!string;`;

                        static if(functionAttributes!overload & FunctionAttribute.nothrow_)
                            lines ~= "    } catch(Exception) {}";

                        lines ~= returnDefault;

                        lines ~= `}`;
                        lines ~= "";
                    }
                }
            }
        }
    }

    return lines.join("\n");
}
94 
/// Parenthesised argument list `(arg0, ..., argN-1)`; null outside CTFE.
private string argNamesParens(int N) @safe pure {
    return __ctfe ? "(" ~ argNames(N) ~ ")" : null;
}
99 
/// Comma-separated names `arg0, arg1, ..., argN-1`; null outside CTFE.
private string argNames(int N) @safe pure {
    import std.range;
    import std.algorithm;
    import std.conv;

    if(!__ctfe) return null;

    string[] names;
    foreach(idx; 0 .. N)
        names ~= "arg" ~ idx.to!string;
    return names.join(", ");
}
108 
/**
   Builds a parameter list such as
   `(prefix_parameters[0] arg0, prefix_parameters[1] arg1)`, with one
   entry per element of `T`. Returns null outside CTFE.
 */
private string typeAndArgsParens(T...)(string prefix) {
    import std.array;
    import std.conv;
    import std.format : format;

    if(!__ctfe) return null;

    auto params = new string[T.length];
    foreach(idx; 0 .. T.length)
        params[idx] = format("%s_parameters[%s] arg%s", prefix, idx, idx);

    return "(" ~ params.join(", ") ~ ")";
}
122 
/**
   Renders the attributes of `F` that a generated mock override repeats,
   as a space-separated string (e.g. `"pure nothrow @safe"`).
   Returns null outside CTFE.
 */
private string functionAttributesString(alias F)() {
    import std.traits: functionAttributes, FunctionAttribute;
    import std.array: join;

    if(!__ctfe) return null;

    // (flag, keyword) pairs, in output order. const and immutable are
    // deliberately absent since the mock needs to alter state.
    immutable FunctionAttribute[] flags = [
        FunctionAttribute.pure_,
        FunctionAttribute.nothrow_,
        FunctionAttribute.trusted,
        FunctionAttribute.safe,
        FunctionAttribute.nogc,
        FunctionAttribute.system,
        FunctionAttribute.shared_,
        FunctionAttribute.property,
    ];
    immutable string[] keywords = [
        "pure", "nothrow", "@trusted", "@safe", "@nogc", "@system", "shared", "@property",
    ];

    immutable attrs = functionAttributes!F;
    string[] present;
    foreach(idx, flag; flags) {
        if(attrs & flag)
            present ~= keywords[idx];
    }
    return present.join(" ");
}
148 
/**
   Expectation bookkeeping shared by the mock implementations in this module.

   Keeps parallel logs of expected calls (filled by `expect`) and actual
   calls. The mixing-in type appends the actual calls to `calledFuncs` and
   `calledValues` whenever one of its mocked functions is invoked.
 */
mixin template MockImplCommon() {
    bool _verified;          // true once `verify` has run (see `expectCalled`)
    string[] expectedFuncs;  // expected function names, in call order
    string[] calledFuncs;    // actually called function names, in call order
    string[] expectedValues; // stringified expected arguments; "" means "don't check"
    string[] calledValues;   // stringified actual arguments

    /**
       Registers an expected call to `funcName`. When `values` are given,
       the call's arguments must match them, compared as
       `tuple(values).to!string`. Otherwise any arguments are accepted.
     */
    void expect(string funcName, V...)(auto ref V values) {
        import std.conv: to;
        import std.typecons: tuple;

        expectedFuncs ~= funcName;
        static if(V.length > 0)
            expectedValues ~= tuple(values).to!string;
        else
            expectedValues ~= "";
    }

    /**
       Registers an expected call to `func` and immediately verifies all
       expectations so far. Failures are reported at the caller's
       `file`/`line`. Afterwards the mock can be verified again, so further
       expectations can be checked later.
     */
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    /**
       Checks that the n-th expected call matches the n-th actual call. The
       function name must match. The arguments must also match, unless the
       expectation recorded none. Actual calls beyond the expected ones are
       not checked.

       A mismatch, or verifying twice, is reported through
       `unit_threaded.exception.fail`. An argument mismatch is reported by
       throwing `UnitTestException`.
     */
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range: repeat, take, join;
        import std.conv: to;
        import unit_threaded.exception: fail, UnitTestException;

        if(_verified)
            fail("Mock already verified", file, line);

        _verified = true;

        foreach(i, expectedFunc; expectedFuncs) {

            if(i >= calledFuncs.length)
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFunc ~ " did not happen", file, line);

            if(expectedFunc != calledFuncs[i])
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFunc ~ " but got " ~ calledFuncs[i] ~
                     " instead",
                     file, line);

            // An empty expected value means the arguments were not specified.
            if(expectedValues[i] != "" && expectedValues[i] != calledValues[i])
                throw new UnitTestException([expectedFunc ~ " was called with unexpected " ~ calledValues[i],
                                             " ".repeat.take(expectedFunc.length + 4).join ~
                                             "instead of the expected " ~ expectedValues[i]] ,
                                            file, line);
        }
    }
}
201 
/// True when the symbol/value `T` has type `string`.
private template isString(alias T) {
    enum isString = is(typeof(T) == string);
}
203 
/**
   A mock object that conforms to an interface/class.

   `Mock!T` owns a `MockAbstract`, a class deriving from `T` whose virtual
   methods are generated by `implMixinStr`. It forwards to it via
   `alias this`, so a `Mock!T` can be passed wherever a `T` is expected.
   Calls are recorded and can be checked with `expect`/`expectCalled`/`verify`.
   If `verify` has not run when the `Mock` is destroyed, the destructor runs it.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    class MockAbstract: T {
        import std.conv: to;
        import std.traits: Parameters, ReturnType;
        import std.typecons: tuple;

        //static if(__traits(identifier, T) == "foobarbaz")
        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    /// Allocates the implementation; the int is a dummy that forces a constructor.
    this(int/* force constructor*/) {
        _impl = new MockAbstract;
    }

    /// Verifies the expectations unless that was already done.
    ~this() pure @safe {
        // A default-initialised Mock (e.g. `Mock!T m;`, or the old value
        // destroyed by an assignment) has no implementation object.
        // Reading `_verified` through it would dereference null.
        if(_impl !is null && !_verified) verify;
    }

    /**
       Sets the values returned by successive calls to the first overload
       of `funcName`. Each call consumes one value; once they are exhausted
       the return type's `.init` is returned.
     */
    void returnValue(string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index. e.g.:
       ---------
       interface Interface { void foo(int); void foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo"); // int overload
       m.returnValue!(1, "foo"); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        import std.conv: text;
        // Name of the queue declared by implMixinStr for this overload.
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach(v; values)
            mixin(varName ~ ` ~=  v;`);
    }

    // Compile-time guard: only virtual methods get mocked implementations.
    private static void assertFunctionIsVirtual(string funcName)() {
        alias member = Identity!(__traits(getMember, T, funcName));

        static assert(__traits(isVirtualMethod, member),
                      "Cannot use returnValue on '" ~ funcName ~ "'");
    }
}
264 
/// Builds one `import X;` line per module name; returns null outside CTFE.
private string importsString(string module_, string[] Modules...) {
    if(__ctfe) {
        string code = "import " ~ module_ ~ ";\n";
        foreach(idx; 0 .. Modules.length)
            code ~= "import " ~ Modules[idx] ~ ";\n";
        return code;
    }
    return null;
}
274 
/**
   Helper function for creating a Mock object.

   Returns: a `Mock!T` constructed through its `this(int)` constructor,
   so its implementation object is already allocated.
 */
auto mock(T)() {
    return Mock!T(0);
}
279 
///
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    // Code under test: doubles what Foo.foo returns.
    int doubledFoo(Foo dependency) {
        return dependency.foo(5, "foobar") * 2;
    }

    auto fooMock = mock!Foo;
    fooMock.expect!"foo";
    doubledFoo(fooMock);
}
296 
297 
///
@("mock interface positive test with params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    // Code under test: calls Foo.foo with fixed arguments.
    int doubledFoo(Foo dependency) {
        return dependency.foo(5, "foobar") * 2;
    }

    auto fooMock = mock!Foo;
    fooMock.expect!"foo"(5, "foobar");
    doubledFoo(fooMock);
}
314 
315 
///
@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int doubledFoo(Foo dependency) {
        return dependency.foo(5, "foobar") * 2;
    }

    auto fooMock = mock!Foo;
    doubledFoo(fooMock);
    // Check after the fact rather than setting expectations up front.
    fooMock.expectCalled!"foo"(5, "foobar");
}
332 
///
@("interface return value")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubledTimesN(Foo dependency) {
        return dependency.timesN(3) * 2;
    }

    auto fooMock = mock!Foo;
    fooMock.returnValue!"timesN"(42);
    assert(doubledTimesN(fooMock) == 84);
}
350 
///
@("interface return values")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubledTimesN(Foo dependency) {
        return dependency.timesN(3) * 2;
    }

    auto fooMock = mock!Foo;
    fooMock.returnValue!"timesN"(42, 12);
    // Programmed values are consumed in order, then .init (0) is returned.
    foreach(expected; [84, 24, 0])
        assert(doubledTimesN(fooMock) == expected);
}
369 
/**
   Compile-time description of the values that a `mockStruct` returns from
   the function named `function_`, in order.
 */
struct ReturnValues(string function_, T...) if(from!"std.meta".allSatisfy!(isValue, T)) {
    alias funcName = function_;
    alias Values = T;

    /// The values as an array whose element type is that of the first value.
    static auto values() {
        alias Element = typeof(T[0]);
        Element[] collected;
        foreach(value; T)
            collected ~= value;
        return collected;
    }
}
382 
/// True when `T` is an instantiation of `ReturnValues`.
template isReturnValue(alias T) {
    enum isReturnValue = is(T : ReturnValues!U, U...);
}

/// True when `T` has a type, i.e. is a value rather than a type.
template isValue(alias T) {
    enum isValue = is(typeof(T));
}
385 
386 
/**
   Version of mockStruct that accepts 0 or more values of the same
   type. Whatever function is called on it, these values will
   be returned one by one. The limitation is that if more than one
   function is called on the mock, they all return the same type.

   Every call made through a mutable mock is recorded: its name goes to
   `calledFuncs` and its stringified arguments to `calledValues`, for use
   with `expect`/`expectCalled`/`verify`. Once the values are exhausted,
   calls return the `.init` value of the first value's type. With no
   values at all, calls return `void`.
 */
auto mockStruct(T...)(auto ref T returns) {

    struct Mock {

        // State lives behind a pointer so that copies of this struct (e.g.
        // when passed by value to code under test) share the recorded
        // calls and the remaining return values.
        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            static if(T.length > 0) {
                alias FirstType = typeof(returns[0]);
                // Values still to be returned, front first.
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            // Handles any member call `mock.funcName(values)`.
            auto opDispatch(string funcName, this This, V...)
                           (auto ref V values)
            {

                import std.conv: to;
                import std.typecons: tuple;

                // Calls through a const/immutable reference can't append to
                // the logs or consume return values.
                enum isMutable = !is(This == const) && !is(This == immutable);

                static if(isMutable) {
                    calledFuncs ~= funcName;
                    calledValues ~= tuple(values).to!string;
                }

                static if(T.length > 0) {

                    if(_returnValues.length == 0) return typeof(_returnValues[0]).init;
                    auto ret = _returnValues[0];
                    static if(isMutable)
                        _returnValues = _returnValues[1..$];
                    return ret;
                }
            }
        }
    }

    Mock m;
    m._impl = new Mock.MockImpl;
    static if(T.length > 0) {
        foreach(r; returns)
            m._impl._returnValues ~= r;
    }

    return m;
}
444 
/**
   Version of mockStruct that accepts a compile-time mapping
   of function name to return values. Each template parameter
   must be a value of type `ReturnValues`.

   Calls made through a mutable mock are recorded for use with
   `expect`/`expectCalled`/`verify`. A call to a function named in one of
   the `ReturnValues` returns that entry's values one by one. Once they
   are exhausted it returns the `.init` value of the element type, as the
   value-based `mockStruct` overload does. Calls to other names return
   `void`.
 */
auto mockStruct(T...)() if(T.length > 0 && from!"std.meta".allSatisfy!(isReturnValue, T)) {

    struct Mock {
        mixin MockImplCommon;

        // Index of the next value to return, per function name.
        int[string] _retIndices;

        auto opDispatch(string funcName, this This, V...)
                       (auto ref V values)
        {

            import std.conv: text;
            import std.typecons: tuple;

            // Calls through a const/immutable reference can't append to the
            // logs or advance the indices.
            enum isMutable = !is(This == const) && !is(This == immutable);

            static if(isMutable) {
                calledFuncs ~= funcName;
                calledValues ~= tuple(values).text;
            }

            foreach(retVal; T) {
                static if(retVal.funcName == funcName) {
                    auto retValues = retVal.values;
                    const size_t index = _retIndices[funcName];
                    // Out of programmed values: return .init rather than
                    // failing with a RangeError.
                    if(index >= retValues.length)
                        return typeof(retValues[0]).init;
                    auto ret = retValues[index];
                    static if(isMutable)
                        ++_retIndices[funcName];
                    return ret;
                }
            }
        }
    }

    Mock mock;

    // Every programmed function starts at its first value.
    foreach(retVal; T) {
        mock._retIndices[retVal.funcName] = 0;
    }

    return mock;
}
495 
///
@("mock struct positive")
@safe pure unittest {
    void callFoobar(M)(M dependency) {
        dependency.foobar;
    }

    auto stub = mockStruct;
    stub.expect!"foobar";
    callFoobar(stub);
    stub.verify;
}
507 
508 
///
@("mock struct values positive")
@safe pure unittest {
    void callFoobar(M)(M dependency) {
        dependency.foobar(2, "quux");
    }

    auto stub = mockStruct;
    stub.expect!"foobar"(2, "quux");
    callFoobar(stub);
    stub.verify;
}
521 
522 
///
@("struct return value")
@safe pure unittest {

    int doubledTimesN(M)(M dependency) {
        return dependency.timesN(3) * 2;
    }

    auto stub = mockStruct(42, 12);
    // Values come back in order, then .init (0).
    foreach(expected; [84, 24, 0])
        assert(doubledTimesN(stub) == expected);
    stub.expectCalled!"timesN";
}
537 
///
@("struct expectCalled")
@safe pure unittest {
    void callFoobar(M)(M dependency) {
        dependency.foobar(2, "quux");
    }

    auto stub = mockStruct;
    callFoobar(stub);
    stub.expectCalled!"foobar"(2, "quux");
}
549 
///
@("mockStruct different return types for different functions")
@safe pure unittest {
    alias lengths = ReturnValues!("length", 5);
    alias greetings = ReturnValues!("greet", "hello");
    auto stub = mockStruct!(lengths, greetings);

    assert(stub.length == 5);
    assert(stub.greet("bar") == "hello");
    stub.expectCalled!"length";
    stub.expectCalled!"greet"("bar");
}
560 
///
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    alias lengths = ReturnValues!("length", 5, 3);
    alias greetings = ReturnValues!("greet", "hello", "g'day");
    auto stub = mockStruct!(lengths, greetings);

    assert(stub.length == 5);
    stub.expectCalled!"length";
    assert(stub.length == 3);
    stub.expectCalled!"length";

    assert(stub.greet("bar") == "hello");
    stub.expectCalled!"greet"("bar");
    assert(stub.greet("quux") == "g'day");
    stub.expectCalled!"greet"("quux");
}
576 
577 
/**
   A mock struct that always throws.

   Any member call `m.name(args...)` throws a new `E` whose message is
   `"name was called"`, with the file and line of the call site. `R` is the
   declared return type of those calls.
 */
auto throwStruct(E = from!"unit_threaded.should".UnitTestException, R = void)() {

    struct Mock {

        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)
                    (auto ref V values) {
            enum message = funcName ~ " was called";
            throw new E(message, file, line);
        }
    }

    return Mock();
}
593 
///
@("throwStruct default")
@safe pure unittest {
    import std.exception: assertThrown;
    import unit_threaded.exception: UnitTestException;

    auto thrower = throwStruct;
    // Both a bare call and a call with arguments must throw.
    assertThrown!UnitTestException(thrower.foo);
    assertThrown!UnitTestException(thrower.bar(1, "foo"));
}