1 module unit_threaded.mock;
2 
3 import std.traits;
4 import std.traits: allSameType, allSatisfy;
5 import unit_threaded.should: UnitTestException;
6 
/// Yields its argument unchanged; lets the result of a `__traits` expression
/// be bound to an `alias` or indexed where a plain symbol is required.
template Identity(alias T) {
    alias Identity = T;
}

/// True when `member` of `T` cannot be accessed from this module, i.e. when
/// `__traits(getMember, T, member)` does not compile. Used to skip such
/// members while generating overrides.
private template isPrivate(T, string member) {
    enum isPrivate = !__traits(compiles, __traits(getMember, T, member));
}
9 
10 
/**
   Generates, as a string of D code, overrides for the abstract member
   functions of `T`, meant to be mixed into a class deriving from `T`.

   A member is processed when `__traits(isAbstractFunction)` holds for it;
   then every overload `i` of member `name` gets:
   - `name_i_parameters` / `name_i_returnType` aliases for that overload;
   - for non-void overloads, a `name_i_returnValues` array used as a queue
     of values to return;
   - an `override` that appends `name` to `calledFuncs` and the
     stringified argument tuple to `calledValues`, then returns the next
     queued value or `name_i_returnType.init` when the queue is empty.

   The mixin site must have `calledFuncs`, `calledValues`, `to`, `tuple`,
   `Parameters` and `ReturnType` in scope.
 */
string implMixinStr(T)() {
    import std.array: join;
    import std.format : format;
    import std.traits: functionAttributes, FunctionAttribute, Parameters, ReturnType, arity;
    import std.conv: text;

    string[] lines;

    // Code that names the i-th overload of memberName inside the mixin.
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`
            .format(memberName, i);
    }

    foreach(memberName; __traits(allMembers, T)) {

        static if(!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if(__traits(isAbstractFunction, member)) {
                foreach(i, overload; __traits(getOverloads, T, memberName)) {

                    enum overloadName = text(memberName, "_", i);

                    enum overloadString = getOverload(memberName, i);
                    lines ~= "private alias %s_parameters = Parameters!(%s);".format(overloadName, overloadString);
                    lines ~= "private alias %s_returnType = ReturnType!(%s);".format(overloadName, overloadString);

                    // All traits below must be taken from *this* overload, not from
                    // `member` (the first overload): overloads may differ in arity,
                    // return type and attributes.
                    enum isNothrow = (functionAttributes!overload & FunctionAttribute.nothrow_) != 0;

                    static if(isNothrow)
                        enum tryIndent = "    ";
                    else
                        enum tryIndent = "";

                    static if(is(ReturnType!overload == void))
                        enum returnDefault = "";
                    else {
                        enum varName = overloadName ~ `_returnValues`;
                        lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                        lines ~= "";
                        // Pop the next queued value, falling back to .init.
                        enum returnDefault = [`    if(` ~ varName ~ `.length > 0) {`,
                                              `        auto ret = ` ~ varName ~ `[0];`,
                                              `        ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                              `        return ret;`,
                                              `    } else`,
                                              `        return %s_returnType.init;`.format(overloadName)];
                    }

                    lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~
                        typeAndArgsParens!(Parameters!overload)(overloadName) ~ ` {`;

                    // The recording code (to!string) may throw; a nothrow override
                    // has to swallow that.
                    static if(isNothrow)
                        lines ~= "try {";

                    lines ~= tryIndent ~ `    calledFuncs ~= "` ~ memberName ~ `";`;
                    lines ~= tryIndent ~ `    calledValues ~= tuple` ~ argNamesParens(arity!overload) ~ `.to!string;`;

                    static if(isNothrow)
                        lines ~= "    } catch(Exception) {}";

                    lines ~= returnDefault;

                    lines ~= `}`;
                    lines ~= "";
                }
            }
        }
    }

    return lines.join("\n");
}
82 
/// `argNames(N)` wrapped in parentheses, e.g. `(arg0, arg1)` for N == 2.
private string argNamesParens(int N) @safe pure {
    immutable inner = argNames(N);
    return "(" ~ inner ~ ")";
}
86 
/// Comma-separated argument names `arg0, arg1, ..., arg(N-1)`; empty for N == 0.
private string argNames(int N) @safe pure {
    import std.array: join;
    import std.conv: to;

    string[] names;
    foreach(idx; 0 .. N)
        names ~= "arg" ~ idx.to!string;
    return names.join(", ");
}
93 
/**
   Parenthesised parameter list for a generated override: one entry
   `<prefix>_parameters[i] argi` per element of `T`. Only the number of
   types in `T` is used.
 */
private string typeAndArgsParens(T...)(string prefix) {
    import std.array: join;
    import std.format : format;

    string[] params;
    foreach(idx; 0 .. T.length)
        params ~= format("%s_parameters[%s] arg%s", prefix, idx, idx);
    return "(" ~ join(params, ", ") ~ ")";
}
103 
/**
   Call bookkeeping shared by every mock: expected calls (name plus
   stringified arguments) and actual calls, and the check between them.

   The host is responsible for appending to `calledFuncs` and
   `calledValues` each time a mocked function is invoked.
 */
mixin template MockImplCommon() {
    bool _verified;
    string[] expectedFuncs;
    string[] calledFuncs;
    string[] expectedValues;
    string[] calledValues;

    /**
       Records that the next call is expected to be `funcName`. When
       `values` are given, the call's arguments must match them as well;
       an empty string means "arguments not checked".
     */
    void expect(string funcName, V...)(auto ref V values) @safe pure {
        import std.conv: to;
        import std.typecons: tuple;

        expectedFuncs ~= funcName;
        static if(V.length == 0)
            expectedValues ~= "";
        else
            expectedValues ~= tuple(values).to!string;
    }

    /// Adds an expectation and verifies immediately, leaving the mock
    /// verifiable again afterwards.
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    /**
       Checks the expected calls, in order, against the recorded ones.
       Fails if verified twice, if an expected call is missing, if a call
       went to a different function, or if arguments differ from a
       non-empty expectation.
     */
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range: repeat, join;
        import std.conv: to;
        import unit_threaded.should: fail, UnitTestException;

        if(_verified)
            fail("Mock already _verified", file, line);

        _verified = true;

        foreach(idx; 0 .. expectedFuncs.length) {
            immutable wanted = expectedFuncs[idx];

            if(idx >= calledFuncs.length)
                fail("Expected nth " ~ idx.to!string ~ " call to " ~ wanted ~ " did not happen", file, line);

            immutable got = calledFuncs[idx];
            if(wanted != got)
                fail("Expected nth " ~ idx.to!string ~ " call to " ~ wanted ~ " but got " ~ got ~ " instead",
                     file, line);

            immutable wantedArgs = expectedValues[idx];
            immutable gotArgs = calledValues[idx];
            if(wantedArgs != "" && wantedArgs != gotArgs) {
                // Indent the second line so it lines up under the first.
                immutable padding = " ".repeat(wanted.length + 4).join;
                throw new UnitTestException([wanted ~ " was called with unexpected " ~ gotArgs,
                                             padding ~ "instead of the expected " ~ wantedArgs],
                                            file, line);
            }
        }
    }
}
156 
/// True when the symbol or value `T` has exactly type `string`.
private template isString(alias T) {
    enum isString = is(typeof(T) == string);
}
158 
/**
   Mock implementation of the class or interface `T`.

   `MockAbstract` subclasses `T`; its overrides of `T`'s abstract member
   functions are generated by `implMixinStr`. Each override records the
   call and returns the next value queued with `returnValue` (or `.init`).
   Through `alias this` a `Mock!T` can be passed wherever a `T` is
   expected, and `expect`, `expectCalled` and `verify` can be called on it
   directly.

   If `verify` has not been called explicitly, the destructor calls it.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    class MockAbstract: T {
        import std.conv: to;
        import std.traits: Parameters, ReturnType;
        import std.typecons: tuple;
        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    /// The `int` is a dummy: D structs cannot declare a default constructor.
    this(int/* force constructor*/) {
        _impl = new MockAbstract;
    }

    /**
       Verifies the recorded calls unless that was already done.

       A default-initialised `Mock` (e.g. `Mock!T m;` without calling the
       constructor) has a null `_impl`. There is nothing to verify then, and
       reading `_verified` would dereference null.
     */
    ~this() pure @safe {
        if(_impl !is null && !_verified) verify;
    }

    /// Queues `values` as return values of the first overload of `funcName`.
    void returnValue(string funcName, V...)(V values) {
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index (order of `__traits(getOverloads)`). Each overloaded call
       returns the queued values one by one, then `.init`. Only non-void
       overloads have a queue. e.g.:
       ---------
       interface Interface { int foo(int); int foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo")(1); // int overload
       m.returnValue!(1, "foo")(2); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        import std.conv: text;
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach(v; values)
            mixin(varName ~ ` ~=  v;`);
    }
}
202 
/// One `import X;` line (newline-terminated) per module, `module_` first.
private string importsString(string module_, string[] Modules...) {
    string result;
    foreach(mod; [module_] ~ Modules)
        result ~= "import " ~ mod ~ ";\n";
    return result;
}
210 
/**
   Creates a `Mock!T` with its implementation object allocated.

   The mock is built directly in the return expression. NOTE(review):
   presumably this matters because `Mock`'s destructor runs `verify`, so an
   extra named copy being destroyed could verify prematurely — confirm.
 */
auto mock(T)() {
    return Mock!T(0);
}
214 
215 
@("mock interface positive test no params")
@safe pure unittest {
    // Only the function name is expected; the arguments are not checked.
    // Verification happens in Mock's destructor when `m` goes out of scope.
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    m.expect!"foo";
    fun(m);
}
231 
@("mock interface positive test with params")
@safe pure unittest {
    import unit_threaded.asserts;

    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    // Matching arguments: verified implicitly by the destructor at scope exit.
    {
        auto m = mock!Foo;
        m.expect!"foo"(5, "foobar");
        fun(m);
    }

    // First argument differs: verify reports the actual and the expected tuple.
    {
        auto m = mock!Foo;
        m.expect!"foo"(6, "foobar");
        fun(m);
        assertExceptionMsg(m.verify,
                           `    source/unit_threaded/mock.d:123 - foo was called with unexpected Tuple!(int, string)(5, "foobar")` ~ "\n" ~
                           `    source/unit_threaded/mock.d:123 -        instead of the expected Tuple!(int, string)(6, "foobar")`);
    }

    // Second argument differs.
    {
        auto m = mock!Foo;
        m.expect!"foo"(5, "quux");
        fun(m);
        assertExceptionMsg(m.verify,
                           `    source/unit_threaded/mock.d:123 - foo was called with unexpected Tuple!(int, string)(5, "foobar")` ~ "\n" ~
                           `    source/unit_threaded/mock.d:123 -        instead of the expected Tuple!(int, string)(5, "quux")`);
    }
}
269 
270 
@("mock interface negative test")
@safe pure unittest {
    import unit_threaded.should;

    interface Foo {
        int foo(int, string) @safe pure;
    }

    // foo is expected but never called. verify sets _verified before failing,
    // so the destructor does not verify a second time.
    auto m = mock!Foo;
    m.expect!"foo";
    m.verify.shouldThrowWithMessage("Expected nth 0 call to foo did not happen");
}
283 
// Test fixture for "mock class positive test". It can't be declared inside the
// unit test itself. NOTE(review): presumably because a class nested in a
// function cannot be subclassed by Mock's generated class — confirm.
// Only the abstract `foo` is overridden by the mock; the final and
// non-abstract members are left alone.
version(unittest)
private class Class {
    abstract int foo(int, string) @safe pure;
    final int timesTwo(int i) @safe pure nothrow const { return i * 2; }
    int timesThree(int i) @safe pure nothrow const { return i * 3; }
}
291 
@("mock class positive test")
@safe pure unittest {
    // Mocking an abstract class: only its abstract member is overridden.

    int fun(Class f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Class;
    m.expect!"foo";
    fun(m);
}
303 
304 
@("mock interface multiple calls")
@safe pure unittest {
    // Expectations are matched against the recorded calls in order.
    interface Foo {
        int foo(int, string) @safe pure;
        int bar(int) @safe pure;
    }

    void fun(Foo f) {
        f.foo(3, "foo");
        f.bar(5);
        f.foo(4, "quux");
    }

    auto m = mock!Foo;
    m.expect!"foo"(3, "foo");
    m.expect!"bar"(5);
    m.expect!"foo"(4, "quux");
    fun(m);
    m.verify;
}
325 
@("interface expectCalled")
@safe pure unittest {
    // expectCalled adds an expectation after the fact and verifies immediately.
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    fun(m);
    m.expectCalled!"foo"(5, "foobar");
}
341 
@("interface return value")
@safe pure unittest {
    import unit_threaded.should;

    // A queued return value is returned by the mocked function.
    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42);
    immutable res = fun(m);
    res.shouldEqual(84);
}
359 
@("interface return values")
@safe pure unittest {
    import unit_threaded.should;

    // Queued values are returned one by one; once exhausted, int.init (0) is returned.
    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42, 12);
    fun(m).shouldEqual(84);
    fun(m).shouldEqual(24);
    fun(m).shouldEqual(0);
}
378 
/// True when `T` denotes a value, i.e. `typeof(T)` is valid.
enum isValue(alias T) = is(typeof(T));

/// True when `T` is an instantiation of `ReturnValues`.
enum isReturnValue(alias T) = is(T: ReturnValues!U, U...);

/**
   Compile-time mapping from a function name to the values it returns, in
   order. Used with the template-argument overload of `mockStruct`.

   Params:
       function_ = name of the mocked function
       T = values to return; each must be appendable to an array of the
           first value's type
 */
struct ReturnValues(string function_, T...) if(allSatisfy!(isValue, T)) {
    alias funcName = function_;
    alias Values = T;

    /// A freshly allocated array holding the values of `T` in order.
    static auto values() {
        alias Elem = typeof(T[0]);
        Elem[] result;
        foreach(v; T)
            result ~= v;
        return result;
    }
}
394 
395 
/**
   Creates a duck-typed mock: any member function may be called on it
   (through `opDispatch`). Every call and its arguments are recorded for
   `expect`/`expectCalled`/`verify`.

   Params:
       returns = zero or more values, each convertible to the type of the
                 first. Whatever function is called, these values are
                 returned one by one; after they run out, `.init` of that
                 type is returned. With no values, calls return `void`.

   The limitation is that all functions called on the mock share one queue,
   and therefore one return type.

   State lives behind a pointer, so copies of the mock (e.g. when passed by
   value to the code under test) share the recorded calls.
 */
auto mockStruct(T...)(auto ref T returns) {

    struct Mock {

        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            static if(T.length > 0) {
                alias FirstType = typeof(returns[0]);
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            auto opDispatch(string funcName, V...)(auto ref V values) {

                import std.conv: to;
                import std.typecons: tuple;

                calledFuncs ~= funcName;
                calledValues ~= tuple(values).to!string;

                static if(T.length > 0) {
                    if(_returnValues.length > 0) {
                        auto next = _returnValues[0];
                        _returnValues = _returnValues[1 .. $];
                        return next;
                    }
                    return FirstType.init;
                }
            }
        }
    }

    Mock result;
    result._impl = new Mock.MockImpl;

    static if(T.length > 0) {
        foreach(value; returns)
            result._impl._returnValues ~= value;
    }

    return result;
}
446 
/**
   Version of `mockStruct` that takes a compile-time mapping from function
   name to return values. Each template argument must be an instantiation
   of `ReturnValues`.

   A call to a function listed in the mapping returns that function's
   values one by one. Once they are exhausted, `.init` of the value type is
   returned, as with the runtime-values `mockStruct` overload. A call to a
   function not in the mapping returns `void`. Every call is recorded for
   `expect`/`expectCalled`/`verify`.

   State lives behind a pointer (as in the other overload), so copies of the
   mock, e.g. when it is passed by value to the code under test, share the
   recorded calls and return-value positions.
 */
auto mockStruct(T...)() if(T.length > 0 && allSatisfy!(isReturnValue, T)) {

    struct Mock {

        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            mixin MockImplCommon;

            // Next index into each function's return values.
            int[string] _retIndices;

            auto opDispatch(string funcName, V...)(auto ref V values) {

                import std.conv: to;
                import std.typecons: tuple;

                calledFuncs ~= funcName;
                calledValues ~= tuple(values).to!string;

                foreach(retVal; T) {
                    static if(retVal.funcName == funcName) {
                        immutable size_t index = _retIndices[funcName]++;
                        auto vals = retVal.values;
                        if(index < vals.length) return vals[index];
                        // Exhausted: fall back to .init instead of a RangeError.
                        return typeof(vals[0]).init;
                    }
                }
            }
        }
    }

    Mock m;
    m._impl = new Mock.MockImpl;

    foreach(retVal; T) {
        m._impl._retIndices[retVal.funcName] = 0;
    }

    return m;
}
489 
490 
@("mock struct positive")
@safe pure unittest {
    // Any function name can be called on a mockStruct and is recorded.
    void fun(T)(T t) {
        t.foobar;
    }
    auto m = mockStruct;
    m.expect!"foobar";
    fun(m);
    m.verify;
}
501 
@("mock struct negative")
@safe pure unittest {
    import unit_threaded.asserts;

    // An expected call that never happened makes verify fail.
    auto m = mockStruct;
    m.expect!"foobar";
    assertExceptionMsg(m.verify,
                       "    source/unit_threaded/mock.d:123 - Expected nth 0 call to foobar did not happen\n");

}
512 
513 
@("mock struct values positive")
@safe pure unittest {
    // Matching argument values pass verification.
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    m.expect!"foobar"(2, "quux");
    fun(m);
    m.verify;
}
525 
@("mock struct values negative")
@safe pure unittest {
    import unit_threaded.asserts;

    // Mismatched arguments: the message lists the actual, then the expected tuple.
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    m.expect!"foobar"(3, "quux");
    fun(m);
    assertExceptionMsg(m.verify,
                       "    source/unit_threaded/mock.d:123 - foobar was called with unexpected Tuple!(int, string)(2, \"quux\")\n" ~
                       "    source/unit_threaded/mock.d:123 -           instead of the expected Tuple!(int, string)(3, \"quux\")");
}
541 
542 
@("struct return value")
@safe pure unittest {
    import unit_threaded.should;

    // Runtime return values are returned in order, then int.init (0).
    int fun(T)(T f) {
        return f.timesN(3) * 2;
    }

    auto m = mockStruct(42, 12);
    fun(m).shouldEqual(84);
    fun(m).shouldEqual(24);
    fun(m).shouldEqual(0);
    // Only checks that the first recorded call was timesN.
    m.expectCalled!"timesN";
}
557 
@("struct expectCalled")
@safe pure unittest {
    // expectCalled on a mockStruct, checked after the call happened.
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    fun(m);
    m.expectCalled!"foobar"(2, "quux");
}
568 
@("mockStruct different return types for different functions")
@safe pure unittest {
    import unit_threaded.should: shouldEqual;
    // Each function gets its own return values (and type) via ReturnValues.
    auto m = mockStruct!(ReturnValues!("length", 5),
                         ReturnValues!("greet", "hello"));
    m.length.shouldEqual(5);
    m.greet("bar").shouldEqual("hello");
    m.expectCalled!"length";
    m.expectCalled!"greet"("bar");
}
579 
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    import unit_threaded.should: shouldEqual;
    // Each expectCalled appends one expectation and re-verifies the whole
    // history, so expectations must stay in step with the calls.
    auto m = mockStruct!(ReturnValues!("length", 5, 3),
                         ReturnValues!("greet", "hello", "g'day"));
    m.length.shouldEqual(5);
    m.expectCalled!"length";
    m.length.shouldEqual(3);
    m.expectCalled!"length";

    m.greet("bar").shouldEqual("hello");
    m.expectCalled!"greet"("bar");
    m.greet("quux").shouldEqual("g'day");
    m.expectCalled!"greet"("quux");
}
595 
596 
@("const(ubyte)[] return type")
@safe pure unittest {
    // Compile-only check: the generated override must be able to declare a
    // return-value queue of, and default-return, a const(ubyte)[].
    interface Interface {
        const(ubyte)[] fun();
    }

    auto m = mock!Interface;
}
605 
@("safe pure nothrow")
@safe pure unittest {
    // Compile-only check: overrides of nothrow functions wrap the call
    // recording in try/catch so they stay nothrow.
    interface Interface {
        int twice(int i) @safe pure nothrow /*@nogc*/;
    }
    auto m = mock!Interface;
}
613 
@("issue 63")
@safe pure unittest {
    import unit_threaded.should;

    // Overloads have separate return-value queues, selected by the
    // overload index passed to returnValue.
    interface InterfaceWithOverloads {
        int func(int) @safe pure;
        int func(string) @safe pure;
    }
    auto m = mock!InterfaceWithOverloads;
    m.returnValue!(0, "func")(3); // int overload
    m.returnValue!(1, "func")(7); // string overload
    m.expect!"func"("foo");
    m.func("foo").shouldEqual(7);
    m.verify;
}
630 
631 
/**
   Creates a mock struct on which calling any function throws.

   Params:
       E = exception type thrown, constructed as `new E(msg, file, line)`
           with msg `"<funcName> was called"`
       R = declared return type of every mocked function
 */
auto throwStruct(E = UnitTestException, R = void)() {

    struct Mock {

        // file/line default to the call site, so the exception points at the caller.
        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)
                    (auto ref V values) {
            enum message = funcName ~ " was called";
            throw new E(message, file, line);
        }
    }

    Mock thrower;
    return thrower;
}
644 
@("throwStruct default")
@safe pure unittest {
    import unit_threaded.should: shouldThrow;
    // By default every call throws UnitTestException, with or without arguments.
    auto m = throwStruct;
    m.foo.shouldThrow!UnitTestException;
    m.bar(1, "foo").shouldThrow!UnitTestException;
}
652 
version(testing_unit_threaded) {

    // Custom exception type used to check that throwStruct throws a user-supplied E.
    class FooException: Exception {
        import std.exception : basicExceptionCtors;
        mixin basicExceptionCtors;
    }

    @("throwStruct custom")
    @safe pure unittest {
        import unit_threaded.should : shouldThrow;

        auto thrower = throwStruct!FooException;
        thrower.foo.shouldThrow!FooException;
        thrower.bar(1, "foo").shouldThrow!FooException;
    }
}
669 
670 
@("throwStruct return value type")
@safe pure unittest {
    import unit_threaded.asserts;
    // With R = int the calls can appear in an int context; they still throw.
    auto m = throwStruct!(UnitTestException, int);
    int i;
    assertExceptionMsg(i = m.foo,
                       "    source/unit_threaded/mock.d:123 - foo was called");
    assertExceptionMsg(i = m.bar,
                       "    source/unit_threaded/mock.d:123 - bar was called");
}