1 module unit_threaded.mock;
2 
3 import std.traits;
4 
/// Lets the result of a `__traits` expression be bound where a symbol alias is
/// required, e.g. `alias m = Identity!(__traits(getMember, T, name));`.
template Identity(alias T) {
    alias Identity = T;
}

/// True when `member` of `T` cannot be accessed from this module, i.e.
/// `__traits(getMember, T, member)` does not compile here.
private template isPrivate(T, string member) {
    enum isPrivate = !__traits(compiles, __traits(getMember, T, member));
}
7 
8 
/**
   Generates the D source that is mixed into `Mock!T.MockAbstract`.

   For every non-private member of `T` whose (first) overload is abstract,
   each overload `i` of member `name` gets:
   - aliases `name_i_parameters` and `name_i_returnType`;
   - for non-void functions, a queue `name_i_returnValues` that the override
     pops from, returning `name_i_returnType.init` once it is empty;
   - an `override` that appends `name` to `calledFuncs` and the stringified
     argument tuple to `calledValues`. For `nothrow` overloads that recording
     is wrapped in a `try` block whose exceptions are swallowed.

   Every property (parameters, return type, attributes, arity) is taken from
   the overload being generated, not from the first overload of the name.

   The generated text expects `Identity`, `Parameters`, `ReturnType`, `tuple`
   and `to` to be visible at the mixin site.
 */
string implMixinStr(T)() {
    import std.array: join;
    import std.format : format;
    import std.traits: functionAttributes, FunctionAttribute, Parameters, ReturnType, arity;
    import std.conv: text;

    string[] lines;

    // Source text naming overload `i` of `memberName`, resolved at the mixin site.
    string getOverload(in string memberName, in size_t i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`
            .format(memberName, i);
    }

    foreach(memberName; __traits(allMembers, T)) {

        static if(!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if(__traits(isAbstractFunction, member)) {
                foreach(i, _; __traits(getOverloads, T, memberName)) {

                    // Inspect THIS overload. `member` is only the first overload, so
                    // using it here produced wrong signatures/returns whenever the
                    // overloads differ in arity, return type or nothrow-ness.
                    alias overload = Identity!(__traits(getOverloads, T, memberName)[i]);

                    enum overloadName = text(memberName, "_", i);
                    enum overloadString = getOverload(memberName, i);
                    enum isNothrow = (functionAttributes!overload & FunctionAttribute.nothrow_) != 0;

                    lines ~= "private alias %s_parameters = Parameters!(%s);".format(overloadName, overloadString);
                    lines ~= "private alias %s_returnType = ReturnType!(%s);".format(overloadName, overloadString);

                    // Extra indentation for the recording statements when they sit inside `try`.
                    enum tryIndent = isNothrow ? "    " : "";

                    static if(is(ReturnType!overload == void))
                        enum returnDefault = "";
                    else {
                        enum varName = overloadName ~ `_returnValues`;
                        lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                        lines ~= "";
                        // Pop the next queued return value, or return .init when none is left.
                        enum returnDefault = [`    if(` ~ varName ~ `.length > 0) {`,
                                              `        auto ret = ` ~ varName ~ `[0];`,
                                              `        ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                              `        return ret;`,
                                              `    } else`,
                                              `        return %s_returnType.init;`.format(overloadName)];
                    }

                    lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~
                        typeAndArgsParens!(Parameters!overload)(overloadName) ~ ` {`;

                    // A nothrow override must not let the (allocating, formatting)
                    // recording code throw, so exceptions from it are swallowed.
                    static if(isNothrow)
                        lines ~= "    try {";

                    lines ~= tryIndent ~ `    calledFuncs ~= "` ~ memberName ~ `";`;
                    lines ~= tryIndent ~ `    calledValues ~= tuple` ~ argNamesParens(arity!overload) ~ `.to!string;`;

                    static if(isNothrow)
                        lines ~= "    } catch(Exception) {}";

                    lines ~= returnDefault;

                    lines ~= `}`;
                    lines ~= "";
                }
            }
        }
    }

    return lines.join("\n");
}
80 
/// The comma-separated names from `argNames(N)` wrapped in parentheses,
/// e.g. `(arg0, arg1)` for `N == 2`.
private string argNamesParens(int N) @safe pure {
    immutable inner = argNames(N);
    return "(" ~ inner ~ ")";
}
84 
/// Comma-separated generated argument names: `arg0, arg1, ..., arg{N-1}`.
/// Returns an empty string when `N` is not positive.
private string argNames(int N) @safe pure {
    import std.conv: to;

    string result;
    foreach(idx; 0 .. N) {
        if(idx > 0)
            result ~= ", ";
        result ~= "arg" ~ idx.to!string;
    }
    return result;
}
91 
/**
   Builds the parenthesised parameter list of a generated override, e.g.
   `(foo_0_parameters[0] arg0, foo_0_parameters[1] arg1)` for two types and
   prefix `foo_0`. Only the number of types in `T` is used; the types
   themselves are referenced through the `<prefix>_parameters` alias.
 */
private string typeAndArgsParens(T...)(string prefix) {
    import std.array: join;
    import std.format : format;

    auto declarations = new string[T.length];
    foreach(idx, ref decl; declarations)
        decl = format("%s_parameters[%s] arg%s", prefix, idx, idx);

    return "(" ~ declarations.join(", ") ~ ")";
}
101 
/**
   Bookkeeping shared by `Mock` and `mockStruct`: the recorded calls, the
   expectations and their verification. The type mixing this in is
   responsible for appending to `calledFuncs` and `calledValues` whenever a
   mocked function runs.
 */
mixin template MockImplCommon() {
    bool verified;           // set to true as soon as verify is entered
    string[] expectedFuncs;  // expected function names, in call order
    string[] calledFuncs;    // names of the functions actually called, in order
    string[] expectedValues; // stringified expected arguments; "" accepts any arguments
    string[] calledValues;   // stringified actual arguments, parallel to calledFuncs

    /**
       Appends an expectation that `funcName` is called at the next position.
       With `values` the call's arguments must match `tuple(values).to!string`;
       without values the arguments are not checked.
     */
    void expect(string funcName, V...)(auto ref V values) @safe pure {
        import std.conv: to;
        import std.typecons: tuple;

        expectedFuncs ~= funcName;

        static if(V.length == 0)
            expectedValues ~= "";
        else
            expectedValues ~= tuple(values).to!string;
    }

    /// `expect!func(values)` immediately followed by `verify(file, line)`.
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(auto ref V values) {
        expect!func(values);
        verify(file, line);
    }

    /**
       Compares the recorded calls with the expectations, position by
       position, reporting failures at `file`/`line`. Fails if already
       verified, if an expected call is missing, if another function was
       called at that position, or if the arguments differ from non-empty
       expected values. Calls past the last expectation are not checked.
     */
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range: repeat, take, join;
        import std.conv: to;
        import unit_threaded.should: fail, UnitTestException;

        if(verified)
            fail("Mock already verified", file, line);

        verified = true;

        foreach(idx, expectedFunc; expectedFuncs) {

            if(idx >= calledFuncs.length)
                fail("Expected nth " ~ idx.to!string ~ " call to " ~ expectedFunc ~ " did not happen",
                     file, line);

            auto calledFunc = calledFuncs[idx];
            if(expectedFunc != calledFunc)
                fail("Expected nth " ~ idx.to!string ~ " call to " ~ expectedFunc ~ " but got " ~ calledFunc ~
                     " instead",
                     file, line);

            auto expectedArgs = expectedValues[idx];
            auto calledArgs = calledValues[idx];
            if(expectedArgs != calledArgs && expectedArgs != "") {
                // the second message line is indented by the function name's length plus 4
                auto indent = " ".repeat.take(expectedFunc.length + 4).join;
                throw new UnitTestException([expectedFunc ~ " was called with unexpected " ~ calledArgs,
                                             indent ~ "instead of the expected " ~ expectedArgs],
                                            file, line);
            }
        }
    }
}
154 
/// True when the type of `T` is exactly `string`.
private template isString(alias T) {
    enum isString = is(typeof(T) == string);
}
156 
/**
   Mock of the interface or class `T`; create it with `mock!T`.

   `MockAbstract` overrides the abstract functions of `T` with code generated
   by `implMixinStr!T`. Those overrides record each call (name and
   stringified arguments), and non-void ones return the values queued with
   `returnValue`, or `.init` once the queue is empty. Expectations are added
   with `expect` and checked with `verify`. If `verify` has not run by the
   time the mock is destroyed, the destructor runs it.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    /// Concrete subclass of `T` whose generated overrides record calls.
    class MockAbstract: T {
        import std.conv: to;
        import std.traits: Parameters, ReturnType;
        import std.typecons: tuple;
        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    /// The `int` is a dummy: structs cannot declare a default constructor,
    /// so this is how `mock!T` forces `_impl` to be allocated.
    this(int/* force constructor*/) {
        _impl = new MockAbstract;
    }

    /// Runs `verify` unless it already ran. A default-initialised `Mock`
    /// (e.g. `Mock!T m;`) has a null `_impl` and must not be dereferenced.
    ~this() pure @safe {
        if(_impl !is null && !_impl.verified)
            _impl.verify;
    }

    /// Queues `values` as successive return values of overload 0 of `funcName`.
    void returnValue(string funcName, V...)(V values) {
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index, i.e. the position in `__traits(getOverloads, T, funcName)`.
       Only non-void functions have a return-value queue. e.g.:
       ---------
       interface Interface { int foo(int); int foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo")(1); // int overload returns 1
       m.returnValue!(1, "foo")(2); // string overload returns 2
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        import std.conv: text;
        // Name of the queue declared by implMixinStr for this overload.
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach(v; values)
            mixin(varName ~ ` ~=  v;`);
    }
}
200 
/// One `import <name>;` line (newline-terminated) for `module_` followed by
/// one for each of `Modules`, in order.
private string importsString(string module_, string[] Modules...) {
    string statements;
    foreach(name; [module_] ~ Modules)
        statements ~= "import " ~ name ~ ";\n";
    return statements;
}
208 
/// Creates a ready-to-use `Mock!T` for the interface or class `T`.
auto mock(T)() {
    enum forceConstructor = 0; // any int selects Mock's allocating constructor
    return Mock!T(forceConstructor);
}
212 
213 
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int callFoo(Foo target) {
        return target.foo(5, "foobar") * 2;
    }

    // no expected values: any arguments to foo are accepted
    auto mocked = mock!Foo;
    mocked.expect!"foo";
    callFoo(mocked);
    // verified by Mock's destructor at the end of the scope
}

@("mock interface positive test with params")
@safe pure unittest {
    import unit_threaded.asserts;

    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int callFoo(Foo target) {
        return target.foo(5, "foobar") * 2;
    }

    // matching arguments, verified by the destructor
    {
        auto mocked = mock!Foo;
        mocked.expect!"foo"(5, "foobar");
        callFoo(mocked);
    }

    // mismatched int argument
    {
        auto mocked = mock!Foo;
        mocked.expect!"foo"(6, "foobar");
        callFoo(mocked);
        assertExceptionMsg(mocked.verify,
                           "    source/unit_threaded/mock.d:123 - foo was called with unexpected Tuple!(int, string)(5, \"foobar\")\n" ~
                           "    source/unit_threaded/mock.d:123 -        instead of the expected Tuple!(int, string)(6, \"foobar\")");
    }

    // mismatched string argument
    {
        auto mocked = mock!Foo;
        mocked.expect!"foo"(5, "quux");
        callFoo(mocked);
        assertExceptionMsg(mocked.verify,
                           "    source/unit_threaded/mock.d:123 - foo was called with unexpected Tuple!(int, string)(5, \"foobar\")\n" ~
                           "    source/unit_threaded/mock.d:123 -        instead of the expected Tuple!(int, string)(5, \"quux\")");
    }
}


@("mock interface negative test")
@safe pure unittest {
    import unit_threaded.should;

    interface Foo {
        int foo(int, string) @safe pure;
    }

    // foo is expected but never called
    auto mocked = mock!Foo;
    mocked.expect!"foo";
    mocked.verify.shouldThrowWithMessage("Expected nth 0 call to foo did not happen");
}
281 
282 // can't be in the unit test itself
version(unittest)
private class Class {
    // only `foo` is abstract, so it is the only member the mock overrides
    abstract int foo(int, string) @safe pure;
    final int timesTwo(int i) @safe pure nothrow const { return 2 * i; }
    int timesThree(int i) @safe pure nothrow const { return 3 * i; }
}

@("mock class positive test")
@safe pure unittest {

    int callFoo(Class target) {
        return target.foo(5, "foobar") * 2;
    }

    auto mocked = mock!Class;
    mocked.expect!"foo";
    callFoo(mocked);
}
301 
302 
@("mock interface multiple calls")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        int bar(int) @safe pure;
    }

    // expectations are matched against the calls in order
    void callAll(Foo target) {
        target.foo(3, "foo");
        target.bar(5);
        target.foo(4, "quux");
    }

    auto mocked = mock!Foo;
    mocked.expect!"foo"(3, "foo");
    mocked.expect!"bar"(5);
    mocked.expect!"foo"(4, "quux");
    callAll(mocked);
    mocked.verify;
}

@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int callFoo(Foo target) {
        return target.foo(5, "foobar") * 2;
    }

    // expectCalled checks after the fact: expect + verify in one step
    auto mocked = mock!Foo;
    callFoo(mocked);
    mocked.expectCalled!"foo"(5, "foobar");
}
339 
@("interface return value")
@safe pure unittest {
    import unit_threaded.should;

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubled(Foo target) {
        return target.timesN(3) * 2;
    }

    auto mocked = mock!Foo;
    mocked.returnValue!"timesN"(42);
    immutable result = doubled(mocked);
    result.shouldEqual(84);
}

@("interface return values")
@safe pure unittest {
    import unit_threaded.should;

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubled(Foo target) {
        return target.timesN(3) * 2;
    }

    // queued values come back in order, then int.init once exhausted
    auto mocked = mock!Foo;
    mocked.returnValue!"timesN"(42, 12);
    doubled(mocked).shouldEqual(84);
    doubled(mocked).shouldEqual(24);
    doubled(mocked).shouldEqual(0);
}
376 
377 
/**
   Creates a duck-typed mock for code templated on its argument type.

   Any member call on the result is accepted through `opDispatch` and
   recorded (name and stringified arguments). When `returns` are given,
   successive calls return them in order, stored as the type of the first
   one, and that type's `.init` once they are used up. Without `returns`,
   calls return void.
 */
auto mockStruct(T...)(auto ref T returns) {

    struct Mock {

        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            static if(T.length > 0) {
                alias FirstType = typeof(returns[0]);
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            auto opDispatch(string funcName, V...)(V values) {
                import std.conv: to;
                import std.typecons: tuple;

                calledFuncs ~= funcName;
                calledValues ~= tuple(values).to!string;

                static if(T.length > 0) {
                    if(_returnValues.length == 0)
                        return FirstType.init;
                    auto front = _returnValues[0];
                    _returnValues = _returnValues[1 .. $];
                    return front;
                }
            }
        }
    }

    Mock mocked;
    mocked._impl = new Mock.MockImpl;
    static if(T.length > 0) {
        foreach(value; returns)
            mocked._impl._returnValues ~= value;
    }
    return mocked;
}
416 
417 
@("mock struct positive")
@safe pure unittest {
    void exercise(T)(T target) {
        target.foobar;
    }
    auto mocked = mockStruct;
    mocked.expect!"foobar";
    exercise(mocked);
    mocked.verify;
}

@("mock struct negative")
@safe pure unittest {
    import unit_threaded.asserts;

    // foobar is expected but never called
    auto mocked = mockStruct;
    mocked.expect!"foobar";
    assertExceptionMsg(mocked.verify,
                       "    source/unit_threaded/mock.d:123 - Expected nth 0 call to foobar did not happen\n");
}


@("mock struct values positive")
@safe pure unittest {
    void exercise(T)(T target) {
        target.foobar(2, "quux");
    }

    auto mocked = mockStruct;
    mocked.expect!"foobar"(2, "quux");
    exercise(mocked);
    mocked.verify;
}

@("mock struct values negative")
@safe pure unittest {
    import unit_threaded.asserts;

    void exercise(T)(T target) {
        target.foobar(2, "quux");
    }

    auto mocked = mockStruct;
    mocked.expect!"foobar"(3, "quux");
    exercise(mocked);
    assertExceptionMsg(mocked.verify,
                       "    source/unit_threaded/mock.d:123 - foobar was called with unexpected Tuple!(int, string)(2, \"quux\")\n" ~
                       "    source/unit_threaded/mock.d:123 -           instead of the expected Tuple!(int, string)(3, \"quux\")");
}
468 
469 
@("struct return value")
@safe pure unittest {
    import unit_threaded.should;

    int doubled(T)(T target) {
        return target.timesN(3) * 2;
    }

    // 42 and 12 come back in order, then int.init once exhausted
    auto mocked = mockStruct(42, 12);
    doubled(mocked).shouldEqual(84);
    doubled(mocked).shouldEqual(24);
    doubled(mocked).shouldEqual(0);
    mocked.expectCalled!"timesN";
}

@("struct expectCalled")
@safe pure unittest {
    void exercise(T)(T target) {
        target.foobar(2, "quux");
    }

    auto mocked = mockStruct;
    exercise(mocked);
    mocked.expectCalled!"foobar"(2, "quux");
}
495 
496 
@("const(ubyte)[] return type")
@safe pure unittest {
    // Only checks that a mock can be generated and instantiated for a
    // function returning const(ubyte)[].
    interface Interface {
        const(ubyte)[] fun();
    }

    auto m = mock!Interface;
}
505 
@("safe pure nothrow")
@safe pure unittest {
    // Generating an override for a nothrow function goes through the
    // try/catch path in implMixinStr; the mock is only instantiated here.
    interface Interface {
        int twice(int i) @safe pure nothrow /*@nogc*/;
    }
    auto mocked = mock!Interface;
}
513 
@("issue 63")
@safe pure unittest {
    import unit_threaded.should;

    // Each overload has its own return-value queue, selected by overload index.
    interface InterfaceWithOverloads {
        int func(int) @safe pure;
        int func(string) @safe pure;
    }
    auto m = mock!InterfaceWithOverloads;
    m.returnValue!(0, "func")(3); // int overload
    m.returnValue!(1, "func")(7); // string overload
    m.expect!"func"("foo");
    m.func("foo").shouldEqual(7);
    m.verify;
}